"""
Ada retriever: scores candidate nodes against a query using precomputed
text-embedding-ada-002 embeddings.

input:  query, query_id, candidate_ids
output: pred_dict: {node_id: similarity}
"""
import sys
from pathlib import Path

import torch

current_file = Path(__file__).resolve()
project_root = current_file.parents[2]
sys.path.append(str(project_root))

from Reasoning.text_retrievers.stark_model import ModelForSTaRKQA

class Ada(ModelForSTaRKQA):
    """Retriever backed by precomputed text-embedding-ada-002 embeddings."""

    def __init__(self, skb, dataset_name, device):
        super().__init__(skb)
        self.emb_dir = f"{project_root}/Reasoning/data/emb/{dataset_name}/"
        self.device = device

        # Precomputed query embeddings: {query_id: tensor}
        self.query_emb_path = self.emb_dir + "text-embedding-ada-002/query/query_emb_dict.pt"
        self.query_emb_dict = torch.load(self.query_emb_path)

        # Precomputed candidate (document) embeddings: {node_id: tensor}
        self.candidate_emb_path = self.emb_dir + "text-embedding-ada-002/doc/candidate_emb_dict.pt"
        self.candidate_emb_dict = torch.load(self.candidate_emb_path)

        assert len(self.candidate_emb_dict) == len(self.candidate_ids)

        # Stack all candidate embeddings into one (num_candidates, dim) matrix,
        # aligned with self.candidate_ids, for fast full-corpus retrieval.
        candidate_embs = [self.candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
        self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)

    def score(self, query, q_id, candidate_ids):
        """Return pred_dict {node_id: similarity (float)} for the given candidate ids."""
        query_emb = self.query_emb_dict[q_id].view(1, -1)

        # Gather the embeddings of the requested candidates.
        candi_embs = [self.candidate_emb_dict[c_id].view(1, -1) for c_id in candidate_ids]
        candidate_embs = torch.cat(candi_embs, dim=0).to(self.device)

        # Dot-product similarity between the query and each candidate.
        similarity = torch.matmul(query_emb.to(self.device), candidate_embs.T).squeeze(dim=0).cpu()

        pred_dict = {c_id: similarity[i].item() for i, c_id in enumerate(candidate_ids)}
        return pred_dict

    def retrieve(self, query, q_id, topk, node_type=None):
        """Return the top-k candidates over the full corpus as {node_id: similarity (float)}."""
        query_emb = self.query_emb_dict[q_id].view(1, -1)

        # Dot-product similarity against every candidate in the corpus.
        similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()
        pred_dict = dict(zip(self.candidate_ids, similarity.view(-1)))

        # Keep only the top-k highest-scoring candidates.
        sorted_pred_ids = sorted(pred_dict, key=lambda x: pred_dict[x], reverse=True)
        selected_pred_ids = sorted_pred_ids[:topk]
        pred_dict = {node_id: pred_dict[node_id].item() for node_id in selected_pred_ids}
        print(f"sorted: {pred_dict}")

        return pred_dict
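

# --- Usage sketch (illustrative only) -------------------------------------
# A minimal, hedged example of how this retriever might be driven. It assumes
# a STaRK-style skb object (e.g. via `stark_qa.load_skb`, if that API is
# available in your environment), a dataset name matching the embedding folder
# layout above, and that the precomputed .pt embedding dicts already exist.
if __name__ == "__main__":
    from stark_qa import load_skb  # assumption: stark_qa exposes load_skb

    dataset_name = "amazon"  # assumption: any STaRK dataset name with embeddings on disk
    device = "cuda" if torch.cuda.is_available() else "cpu"

    skb = load_skb(dataset_name)
    retriever = Ada(skb, dataset_name, device)

    # Pick an existing query id and a few candidate ids for illustration.
    q_id = next(iter(retriever.query_emb_dict))
    example_candidates = retriever.candidate_ids[:5]

    # Score a hand-picked set of candidates for one query.
    scores = retriever.score(query="example query", q_id=q_id, candidate_ids=example_candidates)
    print(scores)  # {node_id: similarity, ...}

    # Or retrieve the top-10 candidates over the whole corpus.
    top10 = retriever.retrieve(query="example query", q_id=q_id, topk=10)
    print(top10)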