|
import numpy as np
import torch
import torch.nn as nn
from optimum.onnxruntime import ORTModelForQuestionAnswering
from transformers import AutoTokenizer

from features.graph_utils import find_best_cluster
from features.text_utils import post_process_answer
|
|
|
class QAEnsembleModel(nn.Module):
    """Ensemble of extractive question-answering models.

    Every model in the ensemble reads each candidate passage and proposes an
    answer span. Spans whose score exceeds ``thr`` are pooled, and the final
    answer is chosen by ``find_best_cluster`` from that pool, seeded with the
    best ranking-weighted answer of the first model.
    """

    def __init__(self, model_name, model_checkpoints, entity_dict,
                 thr=0.1, device="cpu"):
        """Load one QA model per checkpoint plus the shared tokenizer.

        Args:
            model_name: Name or path passed to ``from_pretrained`` for both
                the models and the tokenizer.
            model_checkpoints: Iterable of paths readable by ``torch.load``;
                one ensemble member is created per path.
            entity_dict: Passed through to ``post_process_answer``.
            thr: Minimum answer score for a span to join the candidate pool.
            device: Currently unused; kept for interface compatibility.
                Checkpoints are always mapped to CPU.
        """
        super().__init__()

        # Every ensemble member is built from the same `model_name`, so one
        # tokenizer instance is loaded once and shared between all of them.
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.models = []
        self.tokenizers = []
        for model_checkpoint in model_checkpoints:
            model = ORTModelForQuestionAnswering.from_pretrained(model_name, from_transformers=True)
            # NOTE(review): confirm that `load_state_dict` on an ORT model
            # really applies the fine-tuned weights to the exported ONNX
            # graph, and that `from_transformers` is still accepted by the
            # installed optimum version -- TODO verify.
            model.load_state_dict(torch.load(model_checkpoint, map_location=torch.device('cpu')), strict=False)
            self.models.append(model)
            self.tokenizers.append(tokenizer)
        self.entity_dict = entity_dict
        self.thr = thr

    @staticmethod
    def _extract_answer(model, tokenizer, question, text):
        """Return ``(answer_text, answer_score)`` for one question/passage pair."""
        inputs = tokenizer(question, text, return_tensors="pt")
        input_ids = inputs["input_ids"]

        outputs = model(input_ids, attention_mask=inputs["attention_mask"])
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits

        # NOTE: start and end are chosen independently, so `end_idx` may be
        # smaller than `start_idx`, in which case the slice below is empty.
        start_idx = int(torch.argmax(start_logits))
        end_idx = int(torch.argmax(end_logits))
        answer_text = tokenizer.decode(input_ids[0][start_idx:end_idx + 1])

        # The score is the sum of the best start and end logits (raw logits,
        # not probabilities), converted to a plain Python float.
        answer_score = float(torch.max(start_logits) + torch.max(end_logits))
        return answer_text, answer_score

    def forward(self, question, texts, ranking_scores=None):
        """Answer ``question`` from the candidate passages ``texts``.

        Args:
            question: The question string.
            texts: Candidate passages, each fed to every model.
            ranking_scores: Optional per-passage weights aligned with
                ``texts``; defaults to all ones.

        Returns:
            The post-processed best answer, or ``None`` when no model
            produced a span scoring above ``self.thr``.
        """
        if ranking_scores is None:
            ranking_scores = np.ones((len(texts),))

        curr_answers = []
        curr_scores = []
        answer = None
        best_score = 0

        for i, (model, tokenizer) in enumerate(zip(self.models, self.tokenizers)):
            for text, score in zip(texts, ranking_scores):
                answer_text, answer_score = self._extract_answer(
                    model, tokenizer, question, text)

                # The pool is filtered on the raw (unweighted) score.
                if answer_score > self.thr:
                    curr_answers.append(answer_text)
                    curr_scores.append(answer_score)

                # Only the first model competes for the seed answer, using
                # the score weighted by the passage's ranking score.
                weighted_score = answer_score * score
                if i == 0 and weighted_score > best_score:
                    answer = answer_text
                    best_score = weighted_score

        if not curr_answers:
            return None

        if answer is None:
            # The first model never produced a positive weighted score; fall
            # back to the strongest pooled candidate so the seed is defined.
            answer = curr_answers[int(np.argmax(curr_scores))]

        curr_answers = [post_process_answer(x, self.entity_dict) for x in curr_answers]
        answer = post_process_answer(answer, self.entity_dict)
        new_best_answer = post_process_answer(find_best_cluster(curr_answers, answer), self.entity_dict)
        return new_best_answer
|
|