import tempfile

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from features.text_utils import post_process_answer
from features.graph_utils import find_best_cluster
from optimum.onnxruntime import ORTModelForQuestionAnswering

class QAEnsembleModel(nn.Module):
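    """Ensemble of extractive QA models served through ONNX Runtime.

    Each checkpoint in ``model_checkpoints`` is loaded on top of the
    ``model_name`` architecture; at inference time every model scores every
    candidate context, and the per-model answers are merged by clustering
    (see ``find_best_cluster``).
    """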

    def __init__(self, model_name, model_checkpoints, entity_dict,
                 thr=0.1, device="cpu"):
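        """
        Args:
            model_name: base transformers model the checkpoints were fine-tuned from.
            model_checkpoints: paths to fine-tuned state dicts, one per ensemble member.
            entity_dict: entity mapping passed to post_process_answer.
            thr: minimum span score for an answer to join the candidate pool.
            device: kept for API compatibility; inference runs via ONNX Runtime.
        """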
        super(QAEnsembleModel, self).__init__()
        # self.nlps = []
        self.models = []
        self.tokenizers = []
        for model_checkpoint in model_checkpoints:
            # ORTModel wraps an ONNX Runtime session and has no
            # load_state_dict, so load the fine-tuned checkpoint into the
            # transformers model first, then export the result to ONNX.
            torch_model = AutoModelForQuestionAnswering.from_pretrained(model_name)
            torch_model.load_state_dict(
                torch.load(model_checkpoint, map_location=torch.device('cpu')),
                strict=False,
            )
            export_dir = tempfile.mkdtemp()
            torch_model.save_pretrained(export_dir)
            model = ORTModelForQuestionAnswering.from_pretrained(export_dir, from_transformers=True)  # .half()
            # nlp = pipeline('question-answering', model=model,
            #                tokenizer=model_name, device=device)
            # self.nlps.append(nlp)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.models.append(model)
            self.tokenizers.append(tokenizer)
        self.entity_dict = entity_dict
        self.thr = thr

    def forward(self, question, texts, ranking_scores=None):
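        """Answer `question` against each candidate context in `texts`.

        `ranking_scores` (one weight per context, defaults to all ones) scales
        each span score; the first model's best weighted answer seeds the
        cluster-based selection. Returns the post-processed answer string, or
        None when no span clears `self.thr`.
        """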
        if ranking_scores is None:
            ranking_scores = np.ones((len(texts),))

        curr_answers = []
        curr_scores = []
        answer = None
        best_score = 0
        # for i, nlp in enumerate(self.nlps):
        #     for text, score in zip(texts, ranking_scores):
        #         QA_input = {
        #             'question': question,
        #             'context': text
        #         }
        #         res = nlp(QA_input)
        #         # print(res)
        #         if res["score"] > self.thr:
        #             curr_answers.append(res["answer"])
        #             curr_scores.append(res["score"])
        #         res["score"] = res["score"] * score
        #         if i == 0:
        #             if res["score"] > best_score:
        #                 answer = res["answer"]
        #                 best_score = res["score"]

        for i, (model, tokenizer) in enumerate(zip(self.models, self.tokenizers)):
            for text, score in zip(texts, ranking_scores):
                # Encode the question and context as input ids and attention
                # mask, truncating contexts that exceed the model's max length
                inputs = tokenizer(question, text, truncation=True, return_tensors="pt")
                input_ids = inputs["input_ids"]
                attention_mask = inputs["attention_mask"]
                # Get the start and end logits from the model
                outputs = model(input_ids, attention_mask=attention_mask)
                start_logits = outputs.start_logits
                end_logits = outputs.end_logits
                # Get the most likely start and end indices
                start_idx = torch.argmax(start_logits)
                end_idx = torch.argmax(end_logits)
                # Get the answer span from the input ids
                answer_ids = input_ids[0][start_idx:end_idx + 1]
                # Decode the answer ids to get the answer text (dropping
                # special tokens such as [CLS]/[SEP])
                answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)
                # Score the span with the top start/end logits; note these are
                # unnormalized logits, not pipeline-style probabilities, so
                # thr should be calibrated accordingly
                answer_score = (torch.max(start_logits) + torch.max(end_logits)).item()
                if answer_score > self.thr:
                    curr_answers.append(answer_text)
                    curr_scores.append(answer_score)
                answer_score = answer_score * score
                if i == 0:
                    if answer_score > best_score:
                        answer = answer_text
                        best_score = answer_score
        # No candidate cleared the threshold (or the first model never set a
        # best answer), so there is nothing to cluster.
        if len(curr_answers) == 0 or answer is None:
            return None
        curr_answers = [post_process_answer(x, self.entity_dict) for x in curr_answers]
        answer = post_process_answer(answer, self.entity_dict)
        new_best_answer = post_process_answer(find_best_cluster(curr_answers, answer), self.entity_dict)
        return new_best_answer
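
# Minimal usage sketch (illustrative only): the model name, checkpoint path,
# and entity dictionary below are placeholders, not values confirmed by this
# repo.
#
#     entity_dict = {}  # normally built from the project's entity list
#     qa = QAEnsembleModel(
#         "nguyenvulebinh/vi-mrc-large",     # assumed base model name
#         ["checkpoints/qa_model_1.bin"],    # hypothetical checkpoint path
#         entity_dict,
#     )
#     answer = qa("<question>",
#                 ["<candidate passage 1>", "<candidate passage 2>"],
#                 ranking_scores=np.array([1.0, 0.5]))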