MoR / models /vss.py
GagaLey's picture
framework
7bf4b88
raw
history blame
2.33 kB
import os.path as osp
import torch
from typing import Any, Union, List, Dict
from models.model import ModelForSTaRKQA
from tqdm import tqdm
from stark_qa.evaluator import Evaluator
import sys
sys.path.append("stark/")
class VSS(ModelForSTaRKQA):
    """Vector Similarity Search (VSS) retriever.

    Ranks candidates by the dot product between a query embedding and a
    matrix of pre-computed candidate embeddings loaded from disk.
    """

    def __init__(self,
                 skb,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 emb_model: str = 'text-embedding-ada-002',
                 device: str = 'cuda'):
        """
        Vector Similarity Search

        Args:
            skb (SemiStruct): Knowledge base.
            query_emb_dir (str): Directory to query embeddings.
            candidates_emb_dir (str): Directory to candidate embeddings. It must
                contain the file 'candidate_emb_dict.pt'.
            emb_model (str): Embedding model name.
            device (str): Device that holds the candidate embedding matrix and
                on which the similarity product is computed.

        Raises:
            ValueError: If the number of stored candidate embeddings differs
                from the number of candidate ids.
            KeyError: If a candidate id has no entry in the stored embeddings.
        """
        super().__init__(skb, query_emb_dir=query_emb_dir)
        self.emb_model = emb_model
        self.candidates_emb_dir = candidates_emb_dir
        self.device = device
        # NOTE(review): self.candidate_ids is not set in this class; it is
        # presumably provided by the base-class initializer -- confirm.
        self.evaluator = Evaluator(self.candidate_ids, device)

        candidate_emb_path = osp.join(candidates_emb_dir, 'candidate_emb_dict.pt')
        # Deserialize onto the CPU so that a file saved from GPU tensors can
        # still be read on a host without that GPU; the stacked matrix is
        # moved to the requested device once, below.
        candidate_emb_dict = torch.load(candidate_emb_path, map_location='cpu')
        print(f'Loaded candidate_emb_dict from {candidate_emb_path}!')

        # Explicit check instead of `assert`, which is removed under `python -O`.
        if len(candidate_emb_dict) != len(self.candidate_ids):
            raise ValueError(
                f'candidate_emb_dict has {len(candidate_emb_dict)} entries but '
                f'there are {len(self.candidate_ids)} candidate ids'
            )

        # Flatten each embedding into a single row and stack the rows in
        # candidate_ids order, so row i corresponds to self.candidate_ids[i].
        candidate_embs = [candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
        self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)

    def forward(self,
                query: Union[str, List[str]],
                query_id: Union[int, List[int]],
                **kwargs: Any) -> Union[Dict[Any, torch.Tensor], tuple]:
        """
        Forward pass to compute similarity scores for the given query.

        Args:
            query (Union[str, List[str]]): Query string, or a list of query strings.
            query_id (Union[int, List[int]]): Query index, or a list of query indices.
            **kwargs: Extra keyword arguments forwarded to `get_query_emb`.

        Returns:
            If `query` is a single string: a dict that maps each candidate id
            to its similarity score (each score is a CPU tensor element).
            Otherwise: a tuple `(candidate_ids, scores)` where `candidate_ids`
            is a `torch.LongTensor` of all candidate ids and `scores` is the
            transposed similarity matrix on the CPU.
        """
        query_emb = self.get_query_emb(query, query_id, emb_model=self.emb_model, **kwargs)
        # Dot product of the query embedding(s) with every candidate row.
        # NOTE(review): assumes query_emb has the embedding dimension last
        # (e.g. (num_queries, dim)), giving one score per candidate per query
        # -- confirm against get_query_emb.
        similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()

        if isinstance(query, str):
            # Single query: pair each candidate id with its score.
            return dict(zip(self.candidate_ids, similarity.view(-1)))

        # Batched queries: transpose so candidates index the first dimension.
        return torch.LongTensor(self.candidate_ids), similarity.t()