Question Answering
Transformers
Safetensors
GagaLey's picture
framework
7bf4b88
raw
history blame
2.79 kB
"""
input: query, query_id, candidates_ids
output: pred_dict: {node_id: similarity}
"""
import sys
from pathlib import Path
# Get the absolute path of the current script
current_file = Path(__file__).resolve()
# parents[2] assumes this file lives two directories below the project root
# (e.g. <root>/Reasoning/text_retrievers/this_file.py) — confirm if moved.
project_root = current_file.parents[2]
# Add the project root to the system path
sys.path.append(str(project_root))
import torch
from Reasoning.text_retrievers.stark_model import ModelForSTaRKQA
class Ada(ModelForSTaRKQA):
    """Retriever that scores queries against candidates using precomputed
    OpenAI ``text-embedding-ada-002`` embeddings.

    All embeddings are loaded from disk once in ``__init__``; ``score`` and
    ``retrieve`` then only perform dot-product similarity lookups.
    """

    def __init__(self, skb, dataset_name, device):
        """Load precomputed query/candidate embedding dicts from disk.

        Args:
            skb: knowledge base handed to ``ModelForSTaRKQA``; expected to
                provide ``self.candidate_ids`` (assumption — confirm in base).
            dataset_name: dataset subdirectory under ``Reasoning/data/emb/``.
            device: torch device used for the similarity matmuls.
        """
        super().__init__(skb)
        self.emb_dir = f"{project_root}/Reasoning/data/emb/{dataset_name}/"
        self.query_emb_path = self.emb_dir + "text-embedding-ada-002/query/query_emb_dict.pt"
        # map_location="cpu" so checkpoints saved on GPU still load on
        # CPU-only machines; tensors are moved to `device` explicitly below.
        self.query_emb_dict = torch.load(self.query_emb_path, map_location="cpu")
        self.candidate_emb_path = self.emb_dir + "text-embedding-ada-002/doc/candidate_emb_dict.pt"
        self.candidate_emb_dict = torch.load(self.candidate_emb_path, map_location="cpu")
        self.device = device
        # Every known candidate must have a stored embedding.
        assert len(self.candidate_emb_dict) == len(self.candidate_ids)
        # Stack per-candidate embeddings in candidate_ids order:
        # shape (num_candidates, emb_dim), kept resident on `device`.
        candidate_embs = [self.candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
        self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)

    def score(self, query, q_id, candidate_ids):
        """Score a subset of candidates against one query.

        Args:
            query: unused here; the query embedding is looked up by ``q_id``.
            q_id: key into the precomputed query embedding dict.
            candidate_ids: node ids to score (each must be a key of
                ``self.candidate_emb_dict``).

        Returns:
            dict mapping each candidate id to its similarity as a float.
        """
        # (1, emb_dim)
        query_emb = self.query_emb_dict[q_id].view(1, -1)
        # (num_candidates, emb_dim), gathered in candidate_ids order.
        candi_embs = [self.candidate_emb_dict[c_id].view(1, -1) for c_id in candidate_ids]
        candidates_embs = torch.cat(candi_embs, dim=0).to(self.device)
        # (num_candidates,) dot-product similarities, brought back to CPU.
        similarity = torch.matmul(query_emb.to(self.device), candidates_embs.T).squeeze(dim=0).cpu()
        return {c_id: similarity[i].item() for i, c_id in enumerate(candidate_ids)}

    def retrieve(self, query, q_id, topk, node_type=None):
        """Return the top-``topk`` candidates over ALL candidates for ``q_id``.

        Args:
            query: unused; kept for interface compatibility (embedding is
                looked up by ``q_id``).
            q_id: key into the precomputed query embedding dict.
            topk: number of highest-similarity candidates to return.
            node_type: unused; kept for interface compatibility.

        Returns:
            dict mapping the top-k candidate ids to float similarities,
            highest first.
        """
        # (1, emb_dim)
        query_emb = self.query_emb_dict[q_id].view(1, -1)
        similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()
        # Bugfix: the original built pred_dict only when isinstance(query, str)
        # and hit `return pred_dict` unbound (NameError) otherwise. `query` is
        # never used here, so score unconditionally. Debug print also removed.
        pred_dict = dict(zip(self.candidate_ids, similarity.view(-1)))
        top_ids = sorted(pred_dict, key=pred_dict.get, reverse=True)[:topk]
        return {node_id: pred_dict[node_id].item() for node_id in top_ids}