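"""Gradio app that turns a context paragraph into a multiple-choice question:
one fine-tuned flan-t5-small checkpoint generates a question/answer pair and a
second generates three distractors; the four options are shuffled into a quiz.
"""
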
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import lightning as L
import random
import gradio as gr

MODEL_NAME: str = "google/flan-t5-small"

def load_tokenizer(tokenizer_path: str):
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_path, local_files_only=True)
    return tokenizer

def qa_preprocess_data(context: str, tokenizer: T5Tokenizer):
    # Prepend the task instruction the QA checkpoint was fine-tuned with.
    input_prefix: str = "Generate relevant question and answer for this paragraph:\n "
    inputs = input_prefix + context
    model_inputs = tokenizer(inputs, return_tensors="pt")
    return model_inputs

def distractor_preprocess_data(context: str, question: str,
                               answer: str, tokenizer: T5Tokenizer):
    input_prefix: str = "Generate 3 plausible but incorrect answer options (distractors) for the given question and correct answer, based on the provided context:"
    inputs = f"{input_prefix}\nCONTEXT:\n{context}\nQUESTION: {question}\nANSWER: {answer}"
    model_inputs = tokenizer(inputs, return_tensors="pt")
    return model_inputs
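
# Example prompt assembled by distractor_preprocess_data (placeholders, not
# real model output):
#   Generate 3 plausible but incorrect answer options (distractors) for the
#   given question and correct answer, based on the provided context:
#   CONTEXT:
#   <context paragraph>
#   QUESTION: <generated question>
#   ANSWER: <generated answer>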

class DistractorTrained(L.LightningModule):
    """LightningModule wrapping flan-t5-small for distractor generation."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

    def forward(self, input_ids, attention_mask):
        return self.model.generate(input_ids=input_ids, attention_mask=attention_mask,
                                   num_beams=4, max_new_tokens=80,
                                   do_sample=True, temperature=1.2)

class QATrained(L.LightningModule):
    """LightningModule wrapping flan-t5-small for question/answer generation."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                num_beams: int = 4, max_new_tokens: int = 65,
                temperature: float = 1.2):
        return self.model.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            num_beams=num_beams, max_new_tokens=max_new_tokens,
            do_sample=True, temperature=temperature,
        )

def load_qa_model(model_path: str):
    model = QATrained.load_from_checkpoint(model_path)
    return model


def load_distractor_model(model_path: str):
    model = DistractorTrained.load_from_checkpoint(model_path)
    return model

def predict_qa(model: QATrained, tokenizer: T5Tokenizer, model_inputs,
               device: str = "cpu"):
    model.to(device)
    model.eval()
    with torch.inference_mode():
        generated_ids = model(input_ids=model_inputs["input_ids"].to(device),
                              attention_mask=model_inputs["attention_mask"].to(device))
    generated_ids = generated_ids.cpu()
    decoded_predictions = [tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids]
    return decoded_predictions

def predict_distractor(model: DistractorTrained, tokenizer: T5Tokenizer,
                       model_inputs, device: str = "cpu"):
    model.to(device)
    model.eval()
    with torch.inference_mode():
        generated_ids = model(input_ids=model_inputs["input_ids"].to(device),
                              attention_mask=model_inputs["attention_mask"].to(device))
    generated_ids = generated_ids.cpu()
    decoded_predictions = [tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids]
    return decoded_predictions
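
# NOTE: the parsing in main() assumes the fine-tuned checkpoints emit tagged
# text (inferred from the slice offsets used below):
#   QA model:         "[QUESTION] <question> [ANSWER] <answer>"
#   distractor model: "[OPTION 1] <d1> [OPTION 2] <d2> [OPTION 3] <d3>"
# The magic numbers 11, 9 and 10 are the lengths of "[QUESTION] ", "[ANSWER] "
# and "[OPTION n]" respectively.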

def main(user_input):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer_path: str = "./t5_tokenizer"
    qa_model_path: str = "./qa-t5-small.ckpt"
    distractor_model_path: str = "./distractor_t5-small.ckpt"
    tokenizer = load_tokenizer(tokenizer_path)
    qa_model = load_qa_model(qa_model_path)
    distractor_model = load_distractor_model(distractor_model_path)
    qa_model_inputs = qa_preprocess_data(user_input, tokenizer)
    qa_decoded_predictions = predict_qa(qa_model, tokenizer, qa_model_inputs, device=device)
    qa_decoded_predictions = qa_decoded_predictions[0]
    # Collect every occurrence of the answer marker.
    indices = []
    start = 0
    while True:
        index = qa_decoded_predictions.find("[ANSWER] ", start)
        if index == -1:
            break
        indices.append(index)
        start = index + 1
    # The question sits between "[QUESTION] " and the first answer marker.
    question = qa_decoded_predictions[11:indices[0]].rstrip()
    if len(indices) == 1:
        answer = qa_decoded_predictions[indices[0] + 9:].rstrip()
    else:
        # The model repeated the marker; keep only the first answer span.
        answer = qa_decoded_predictions[indices[0] + 9:indices[1] - 1].rstrip()
    filtered_ans = answer.replace("?", ".")  # sanitize "?" before reuse in the distractor prompt
    distractor_model_inputs = distractor_preprocess_data(user_input, question, filtered_ans, tokenizer)
    distractor_decoded_predictions = predict_distractor(distractor_model, tokenizer, distractor_model_inputs, device=device)
    distractor_decoded_predictions = distractor_decoded_predictions[0]
    option_strings = ["[OPTION 1]", "[OPTION 2]", "[OPTION 3]"]
    option_indices: list[int] = []
    for option in option_strings:
        ind: int = distractor_decoded_predictions.find(option)
        option_indices.append(ind)
    # Each distractor sits between its marker and the start of the next one.
    option1: str = distractor_decoded_predictions[11:option_indices[1]].strip()
    option2: str = distractor_decoded_predictions[option_indices[1] + 10:option_indices[2]].strip()
    option3: str = distractor_decoded_predictions[option_indices[2] + 10:].strip()
    option4: str = answer  # the correct answer
return {"question": question,
"option1": option1,
"option2": option2,
"option3": option3,
"option4": option4}

def shuffle_options(question_data):
    options = [
        question_data["option1"],
        question_data["option2"],
        question_data["option3"],
        question_data["option4"],
    ]
    correct_answer = question_data["option4"]
    random.shuffle(options)
    return options, correct_answer

def process_input(context):
    question_data = main(context)
    options, correct_answer = shuffle_options(question_data)
    # Return a Radio update so the shuffled options become the component's
    # choices (Gradio 4 component-update style); a bare list is not a valid
    # Radio value.
    return question_data["question"], gr.Radio(choices=options, value=None, label="Options"), correct_answer

def check_answer(choice, correct_answer):
    if choice == correct_answer:
        return '<p style="color: #28a745;">Correct!</p>'
    return '<p style="color: #dc3545;">Incorrect! Try again.</p>'

with gr.Blocks() as demo:
    gr.Markdown("# MCQ Generator")
    with gr.Row():
        context_input = gr.Textbox(label="Context Paragraph", lines=5)
        generate_button = gr.Button("Generate Question")
    question_output = gr.Textbox(label="Question")
    options_radio = gr.Radio(label="Options", choices=[])
    submit_button = gr.Button("Submit Answer")
    result_output = gr.HTML()
    correct_answer = gr.State()
    # process_input already returns a Radio update with the new choices, so a
    # single click handler is enough; no follow-up .then() step is needed.
    generate_button.click(
        process_input,
        inputs=[context_input],
        outputs=[question_output, options_radio, correct_answer],
    )
    submit_button.click(
        check_answer,
        inputs=[options_radio, correct_answer],
        outputs=[result_output],
    )

if __name__ == "__main__":
    demo.launch(debug=True)