# e2eqa-wiki/models/qa_model.py
# foxxy-hm's picture
# Update models/qa_model.py
import numpy as np
import torch
import torch.nn as nn
from optimum.onnxruntime import ORTModelForQuestionAnswering
from transformers import AutoTokenizer
# from transformers import AutoModelForQuestionAnswering, pipeline

from features.graph_utils import find_best_cluster
from features.text_utils import post_process_answer
class QAEnsembleModel(nn.Module):
    """Ensemble of extractive question-answering models.

    One ONNX Runtime QA model and one tokenizer are created per checkpoint.
    ``forward`` runs every model over every candidate passage, keeps the
    spans whose score exceeds ``thr`` and hands them to
    ``find_best_cluster`` to pick the final answer.
    """

    def __init__(self, model_name, model_checkpoints, entity_dict,
                 thr=0.1, device="cpu"):
        """Load one model/tokenizer pair per checkpoint.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
            model_checkpoints: Iterable of paths readable by ``torch.load``.
            entity_dict: Passed through to ``post_process_answer``.
            thr: Minimum span score for a span to become a candidate answer.
            device: Accepted for backward compatibility; currently unused
                (checkpoints are always mapped to CPU).
        """
        super().__init__()
        self.models = []
        self.tokenizers = []
        for model_checkpoint in model_checkpoints:
            model = ORTModelForQuestionAnswering.from_pretrained(
                model_name, from_transformers=True)
            # NOTE(review): confirm that the ONNX Runtime wrapper really
            # applies a PyTorch state dict. With strict=False, keys that do
            # not match are ignored silently, so the checkpoint may have no
            # effect on the exported graph.
            model.load_state_dict(
                torch.load(model_checkpoint,
                           map_location=torch.device('cpu')),
                strict=False)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Register the pair; without this forward() has nothing to run.
            self.models.append(model)
            self.tokenizers.append(tokenizer)
        self.entity_dict = entity_dict
        self.thr = thr

    def forward(self, question, texts, ranking_scores=None):
        """Answer ``question`` from the candidate passages ``texts``.

        Args:
            question: Question string.
            texts: Sequence of candidate context strings.
            ranking_scores: Optional per-text weights, paired with ``texts``
                via ``zip``. Defaults to a weight of 1 for every text.

        Returns:
            The post-processed answer chosen by ``find_best_cluster``, or
            ``None`` when no span scored above ``self.thr``.
        """
        if ranking_scores is None:
            ranking_scores = np.ones((len(texts),))

        curr_answers = []   # every span above the threshold, from any model
        curr_scores = []    # unweighted scores, parallel to curr_answers
        answer = None       # best rank-weighted span of the first model
        best_score = 0

        for i, (model, tokenizer) in enumerate(
                zip(self.models, self.tokenizers)):
            for text, score in zip(texts, ranking_scores):
                # Encode the question and context as input ids and mask.
                inputs = tokenizer(question, text, return_tensors="pt")
                input_ids = inputs["input_ids"]
                attention_mask = inputs["attention_mask"]

                # Get the start and end logits from the model.
                outputs = model(input_ids, attention_mask=attention_mask)
                start_logits = outputs.start_logits
                end_logits = outputs.end_logits

                # Most likely start and end positions (one sequence per
                # call, so the flattened argmax is the token position).
                start_idx = int(torch.argmax(start_logits))
                end_idx = int(torch.argmax(end_logits))
                if end_idx < start_idx:
                    # Start and end are picked independently; an inverted
                    # pair is not a span and would decode to "".
                    continue

                answer_ids = input_ids[0][start_idx:end_idx + 1]
                answer_text = tokenizer.decode(answer_ids)

                # NOTE(review): this is a sum of raw logits, not a
                # probability; the default thr=0.1 looks calibrated for a
                # pipeline-style probability score. Confirm the scale.
                answer_score = float(
                    torch.max(start_logits) + torch.max(end_logits))

                if answer_score > self.thr:
                    curr_answers.append(answer_text)
                    curr_scores.append(answer_score)

                # The reference answer comes from the first model only and
                # is ranked by the retrieval-weighted score.
                weighted_score = answer_score * score
                if i == 0 and weighted_score > best_score:
                    answer = answer_text
                    best_score = weighted_score

        if not curr_answers:
            return None

        if answer is None:
            # The first model produced no positively scored span; fall back
            # to the strongest candidate so clustering has a reference.
            answer = curr_answers[int(np.argmax(curr_scores))]

        curr_answers = [post_process_answer(x, self.entity_dict)
                        for x in curr_answers]
        answer = post_process_answer(answer, self.entity_dict)
        new_best_answer = post_process_answer(
            find_best_cluster(curr_answers, answer), self.entity_dict)
        return new_best_answer