File size: 2,333 Bytes
7bf4b88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os.path as osp
import torch
from typing import Any, Union, List, Dict
from models.model import ModelForSTaRKQA
from tqdm import tqdm
from stark_qa.evaluator import Evaluator
import sys
sys.path.append("stark/")


class VSS(ModelForSTaRKQA):
    """Vector Similarity Search retriever.

    Scores every candidate in the knowledge base by the dot product between
    the query embedding and a precomputed candidate embedding.
    """

    def __init__(self,
                 skb,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 emb_model: str = 'text-embedding-ada-002',
                 device: str = 'cuda'):
        """
        Vector Similarity Search

        Args:
            skb (SemiStruct): Knowledge base.
            query_emb_dir (str): Directory to query embeddings.
            candidates_emb_dir (str): Directory containing
                ``candidate_emb_dict.pt``, a mapping from candidate id to its
                embedding tensor.
            emb_model (str): Embedding model name, forwarded to
                ``get_query_emb``.
            device (str): Torch device on which the candidate embeddings are
                kept and similarities are computed.

        Raises:
            ValueError: If the stored candidate embeddings do not correspond
                exactly to ``self.candidate_ids`` (different count, or some
                candidate id has no embedding).
        """
        super().__init__(skb, query_emb_dir=query_emb_dir)
        self.emb_model = emb_model
        self.candidates_emb_dir = candidates_emb_dir
        self.device = device
        # self.candidate_ids is presumably populated by ModelForSTaRKQA.__init__.
        self.evaluator = Evaluator(self.candidate_ids, device)

        candidate_emb_path = osp.join(candidates_emb_dir, 'candidate_emb_dict.pt')
        # Load onto CPU first: the file may have been saved from a GPU, and
        # without map_location torch.load fails on a CPU-only host. The
        # stacked embeddings are moved to `device` in a single transfer below.
        # NOTE: torch.load unpickles the file; only load trusted embeddings.
        candidate_emb_dict = torch.load(candidate_emb_path, map_location='cpu')
        print(f'Loaded candidate_emb_dict from {candidate_emb_path}!')

        # Validate explicitly (an `assert` is stripped under `python -O`) and
        # report which ids are missing instead of failing later with KeyError.
        missing_ids = [idx for idx in self.candidate_ids if idx not in candidate_emb_dict]
        if missing_ids or len(candidate_emb_dict) != len(self.candidate_ids):
            raise ValueError(
                f'{candidate_emb_path} holds {len(candidate_emb_dict)} embeddings '
                f'for {len(self.candidate_ids)} candidates; '
                f'{len(missing_ids)} candidate ids have no embedding '
                f'(e.g. {missing_ids[:5]})'
            )

        # Row i of candidate_embs is the flattened embedding of candidate_ids[i].
        candidate_embs = [candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
        self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)

    def forward(self,
                query: Union[str, List[str]],
                query_id: Union[int, List[int]],
                **kwargs: Any) -> Union[dict, tuple]:
        """
        Compute similarity scores between the query (or queries) and all candidates.

        Args:
            query (str or list[str]): A single query string, or a batch of queries.
            query_id (int or list[int]): Index (or indices) of the query, passed
                to ``get_query_emb`` together with ``query``.
            **kwargs: Forwarded to ``get_query_emb``.

        Returns:
            If ``query`` is a ``str``: a dict mapping each candidate id to its
            similarity score (a 0-d tensor).
            Otherwise: a tuple ``(candidate_ids, scores)`` where ``candidate_ids``
            is a ``torch.LongTensor`` and ``scores`` is the transposed similarity
            matrix. It is presumably ``(num_candidates, num_queries)`` when the
            query embeddings are ``(num_queries, dim)``. TODO confirm.
        """
        query_emb = self.get_query_emb(query, query_id, emb_model=self.emb_model, **kwargs)
        # Dot product against every candidate embedding; results are moved to CPU.
        similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()
        if isinstance(query, str):
            return dict(zip(self.candidate_ids, similarity.view(-1)))
        return torch.LongTensor(self.candidate_ids), similarity.t()