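"""Gradio demo: generate a multiple-choice question from a context paragraph.

Pipeline (as implemented below): one fine-tuned FLAN-T5-small checkpoint
generates a question/answer pair from the paragraph, a second fine-tuned
checkpoint generates three distractors, and the four options are shuffled
before being shown to the user.
"""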
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import lightning as L
import random
import gradio as gr

MODEL_NAME: str = "google/flan-t5-small"

def load_tokenizer(tokenizer_path: str) -> T5Tokenizer:
    # local_files_only avoids a network call; the tokenizer files must already
    # exist under tokenizer_path.
    return T5Tokenizer.from_pretrained(tokenizer_path, local_files_only=True)

def qa_preprocess_data(context: str, tokenizer: T5Tokenizer):
    # Prompt prefix the QA model was fine-tuned with.
    input_prefix: str = "Generate relevant question and answer for this paragraph:\n "
    inputs = input_prefix + context
    # Returns a BatchEncoding holding input_ids and attention_mask tensors.
    return tokenizer(inputs, return_tensors="pt")

def distractor_preprocess_data(context: str, question: str,
                               answer: str, tokenizer: T5Tokenizer):
    input_prefix: str = ("Generate 3 plausible but incorrect answer options "
                         "(distractors) for the given question and correct answer, "
                         "based on the provided context:")
    inputs = f"{input_prefix}\nCONTEXT:\n{context}\nQUESTION: {question}\nANSWER: {answer}"
    return tokenizer(inputs, return_tensors="pt")

class DistractorTrained(L.LightningModule):
    """Inference-only wrapper around the fine-tuned distractor checkpoint."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

    def forward(self, input_ids, attention_mask):
        # Beam sampling; temperature > 1 encourages more varied distractors.
        return self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                   num_beams=4, max_new_tokens=80,
                                   do_sample=True, temperature=1.2)

class QATrained(L.LightningModule):
    """Inference-only wrapper around the fine-tuned question/answer checkpoint."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                num_beams: int = 4, max_new_tokens: int = 65,
                temperature: float = 1.2):
        return self.model.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            num_beams=num_beams, max_new_tokens=max_new_tokens,
            do_sample=True, temperature=temperature,
        )

def load_qa_model(model_path: str) -> QATrained:
    return QATrained.load_from_checkpoint(model_path)

def load_distractor_model(model_path: str) -> DistractorTrained:
    return DistractorTrained.load_from_checkpoint(model_path)

def predict_qa(model: QATrained, tokenizer: T5Tokenizer, model_inputs,
               device: str = "cpu") -> list[str]:
    model.to(device)
    model.eval()
    with torch.inference_mode():
        generated_ids = model(input_ids=model_inputs["input_ids"].to(device),
                              attention_mask=model_inputs["attention_mask"].to(device))
    generated_ids = generated_ids.cpu()
    return [tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids]

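# Minimal usage sketch (the context sentence is made-up illustrative text;
# the paths are the ones used in main() below):
#   tokenizer = load_tokenizer("./t5_tokenizer")
#   qa_model = load_qa_model("./qa-t5-small.ckpt")
#   inputs = qa_preprocess_data("The Nile is the longest river in Africa.", tokenizer)
#   print(predict_qa(qa_model, tokenizer, inputs)[0])
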
def predict_distractor(model: DistractorTrained, tokenizer: T5Tokenizer,
                       model_inputs, device: str = "cpu") -> list[str]:
    model.to(device)
    model.eval()
    with torch.inference_mode():
        generated_ids = model(input_ids=model_inputs["input_ids"].to(device),
                              attention_mask=model_inputs["attention_mask"].to(device))
    generated_ids = generated_ids.cpu()
    return [tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids]

def main(user_input):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer_path: str = "./t5_tokenizer"
    qa_model_path: str = "./qa-t5-small.ckpt"
    distractor_model_path: str = "./distractor_t5-small.ckpt"
    # Note: the tokenizer and both models are reloaded on every request, which
    # is slow but keeps this demo simple.
    tokenizer = load_tokenizer(tokenizer_path)
    qa_model = load_qa_model(qa_model_path)
    distractor_model = load_distractor_model(distractor_model_path)
    qa_model_inputs = qa_preprocess_data(user_input, tokenizer)
    qa_decoded_predictions = predict_qa(qa_model, tokenizer, qa_model_inputs, device=device)
    qa_decoded_predictions = qa_decoded_predictions[0]
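    # The QA model is assumed to emit text of the form
    #   "[QUESTION] <question text> [ANSWER] <answer text>";
    # the tag lengths drive the slice offsets below. Scan for every
    # "[ANSWER] " tag in case the model emits more than one.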
    indices = []
    start = 0
    while True:
        index = qa_decoded_predictions.find("[ANSWER] ", start)
        if index == -1:
            break
        indices.append(index)
        start = index + 1
    # The question sits between "[QUESTION] " and the first "[ANSWER] " tag.
    question = qa_decoded_predictions[len("[QUESTION] "):indices[0]].rstrip()
    if len(indices) == 1:
        answer = qa_decoded_predictions[indices[0] + len("[ANSWER] "):].rstrip()
    else:
        # Keep only the first answer if the model emitted several.
        answer = qa_decoded_predictions[indices[0] + len("[ANSWER] "):indices[1] - 1].rstrip()
    # Normalize stray question marks before building the distractor prompt.
    filtered_ans = answer.replace("?", ".")
    distractor_model_inputs = distractor_preprocess_data(user_input, question, filtered_ans, tokenizer)
    distractor_decoded_predictions = predict_distractor(distractor_model, tokenizer,
                                                        distractor_model_inputs, device=device)
    distractor_decoded_predictions = distractor_decoded_predictions[0]
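    # The distractor model is assumed to emit text of the form
    #   "[OPTION 1] <d1> [OPTION 2] <d2> [OPTION 3] <d3>";
    # each option is recovered by slicing between consecutive tag positions.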
    option_strings = ["[OPTION 1]", "[OPTION 2]", "[OPTION 3]"]
    option_indices: list[int] = [distractor_decoded_predictions.find(tag)
                                 for tag in option_strings]
    option1: str = distractor_decoded_predictions[len("[OPTION 1]"):option_indices[1]].strip()
    option2: str = distractor_decoded_predictions[option_indices[1] + len("[OPTION 2]"):option_indices[2]].strip()
    option3: str = distractor_decoded_predictions[option_indices[2] + len("[OPTION 3]"):].strip()
    option4: str = answer  # the correct answer rides along as the fourth option
    return {"question": question,
            "option1": option1,
            "option2": option2,
            "option3": option3,
            "option4": option4}

def shuffle_options(question_data):
    options = [
        question_data["option1"],
        question_data["option2"],
        question_data["option3"],
        question_data["option4"],
    ]
    correct_answer = question_data["option4"]
    random.shuffle(options)
    return options, correct_answer

def process_input(context):
    question_data = main(context)
    options, correct_answer = shuffle_options(question_data)
    # Return a fresh gr.Radio so the component's *choices* are replaced;
    # returning a plain list would be interpreted as the selected value.
    return question_data["question"], gr.Radio(choices=options, label="Options"), correct_answer

def check_answer(choice, correct_answer):
    if choice == correct_answer:
        return '<p style="color: #28a745;">Correct!</p>'
    return '<p style="color: #dc3545;">Incorrect! Try again.</p>'

with gr.Blocks() as demo:
    gr.Markdown("# MCQ Generator")
    with gr.Row():
        context_input = gr.Textbox(label="Context Paragraph", lines=5)
    generate_button = gr.Button("Generate Question")
    question_output = gr.Textbox(label="Question")
    options_radio = gr.Radio(label="Options", choices=[])
    submit_button = gr.Button("Submit Answer")
    result_output = gr.HTML()
    correct_answer = gr.State()  # holds the correct option across events
    generate_button.click(
        process_input,
        inputs=[context_input],
        outputs=[question_output, options_radio, correct_answer],
    )
    submit_button.click(
        check_answer,
        inputs=[options_radio, correct_answer],
        outputs=[result_output],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
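
# Expected local artifacts (see main()): the ./t5_tokenizer directory plus the
# ./qa-t5-small.ckpt and ./distractor_t5-small.ckpt Lightning checkpoints.
# Running this file directly launches the Gradio demo.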