import sys
import os
import os.path as osp
from typing import Any, Union, List, Dict

import torch
import torch.nn as nn

from stark_qa.tools.api import get_api_embeddings, get_sentence_transformer_embeddings, get_contriever_embeddings
from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings
from stark_qa.evaluator import Evaluator


class ModelForSTaRKQA(nn.Module):
    """
    Base class for STaRK-QA retrieval models.

    Holds the knowledge base, a disk-backed cache of query embeddings, and an
    evaluator over the candidate set. Subclasses implement `forward`.
    """

    def __init__(self, skb, query_emb_dir='.'):
        """
        Initialize the model with the given knowledge base.

        Args:
            skb: Knowledge base containing candidate information
                (must expose `candidate_ids` and `num_candidates`).
            query_emb_dir (str): Directory holding the cached query-embedding
                dict (`query_emb_dict.pt`); defaults to the current directory.
        """
        super(ModelForSTaRKQA, self).__init__()
        self.skb = skb
        self.candidate_ids = skb.candidate_ids
        self.num_candidates = skb.num_candidates
        self.query_emb_dir = query_emb_dir

        # Single source of truth for the cache file path; reused when saving
        # newly computed embeddings in `get_query_emb`.
        self.query_emb_path = osp.join(self.query_emb_dir, 'query_emb_dict.pt')
        if osp.exists(self.query_emb_path):
            print(f'Load query embeddings from {self.query_emb_path}')
            # NOTE(review): torch.load without weights_only — the cache file is
            # assumed trusted (written by this class below).
            self.query_emb_dict = torch.load(self.query_emb_path)
        else:
            self.query_emb_dict = {}

        self.evaluator = Evaluator(self.candidate_ids)

    def forward(self,
                query: Union[str, List[str]],
                candidates: List[int] = None,
                query_id: Union[int, List[int]] = None,
                **kwargs: Any) -> Dict[str, Any]:
        """
        Compute predictions for the given query.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            candidates (Union[list, None]): A list of candidate ids (optional).
            query_id (Union[int, list, None]): Query index (optional).

        Returns:
            pred_dict (dict): A dictionary of predicted scores or answer ids.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def get_query_emb(self,
                      query: Union[str, List[str]],
                      query_id: Union[int, List[int]],
                      emb_model: str = 'text-embedding-ada-002',
                      **encode_kwargs) -> torch.Tensor:
        """
        Retrieve or compute embeddings for the given queries.

        Embeddings are served from the on-disk cache when every requested
        `query_id` is present; otherwise they are computed with `emb_model`
        and the cache is updated and persisted.

        Args:
            query (Union[str, List[str]]): Query string(s).
            query_id (Union[int, List[int], None]): Query index/indices; when
                None the cache is bypassed entirely.
            emb_model (str): Embedding model name passed to `get_embeddings`.
            **encode_kwargs: Extra arguments forwarded to the encoder.

        Returns:
            torch.Tensor: Query embeddings of shape (num_queries, emb_dim).
        """
        if isinstance(query_id, int):
            query_id = [query_id]
        if isinstance(query, str):
            query = [query]

        if query_id is None:
            # No ids to key the cache on — always encode.
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
        elif set(query_id).issubset(self.query_emb_dict.keys()):
            # Cache hit for every requested id: each entry is stored (1, d).
            query_emb = torch.cat([self.query_emb_dict[qid] for qid in query_id], dim=0)
        else:
            # At least one miss: encode the whole batch, then persist the
            # updated cache so later runs can reuse it.
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
            for qid, emb in zip(query_id, query_emb):
                self.query_emb_dict[qid] = emb.view(1, -1)
            torch.save(self.query_emb_dict, self.query_emb_path)

        return query_emb.view(len(query), -1)

    def evaluate(self,
                 pred_dict: Dict[int, float],
                 answer_ids: Union[torch.LongTensor, List[Any]],
                 metrics: List[str] = ['mrr', 'hit@3', 'recall@20'],
                 **kwargs: Any) -> Dict[str, float]:
        """
        Evaluate the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Predicted answer ids or scores.
            answer_ids (torch.LongTensor): Ground truth answer ids.
            metrics (List[str]): Metrics to compute, e.g. 'mrr', 'hit@k',
                'recall@k', 'precision@k', 'map@k', 'ndcg@k'.

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics.
        """
        return self.evaluator(pred_dict, answer_ids, metrics)

    def evaluate_batch(self,
                       pred_ids: List[int],
                       pred: torch.Tensor,
                       answer_ids: Union[torch.LongTensor, List[Any]],
                       metrics: List[str] = ['mrr', 'hit@3', 'recall@20'],
                       **kwargs: Any) -> Dict[str, float]:
        """
        Evaluate a batch of predictions; thin wrapper over
        `Evaluator.evaluate_batch`.

        Args:
            pred_ids (List[int]): Candidate ids corresponding to `pred` scores.
            pred (torch.Tensor): Predicted scores for `pred_ids`.
            answer_ids (torch.LongTensor): Ground truth answer ids.
            metrics (List[str]): Metrics to compute.

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics.
        """
        return self.evaluator.evaluate_batch(pred_ids, pred, answer_ids, metrics)


def get_embeddings(text, model_name, **encode_kwargs):
    """
    Embed the given text with the encoder selected by `model_name`.

    Dispatch is by substring match on `model_name`; anything unrecognized
    falls through to the API-based encoder.

    Args:
        text (Union[str, List[str]]): The input text(s) to be embedded.
        model_name (str): Model name.
        **encode_kwargs: Extra arguments forwarded to the chosen encoder
            (not passed to the sentence-transformer/contriever paths,
            which take only the text).

    Returns:
        torch.Tensor: Embeddings of shape (num_texts, emb_dim).
    """
    if isinstance(text, str):
        text = [text]
    if 'GritLM' in model_name:
        emb = get_gritlm_embeddings(text, model_name, **encode_kwargs)
    elif 'LLM2Vec' in model_name:
        emb = get_llm2vec_embeddings(text, model_name, **encode_kwargs)
    elif 'all-mpnet-base-v2' in model_name:
        emb = get_sentence_transformer_embeddings(text)
    elif 'contriever' in model_name:
        emb = get_contriever_embeddings(text)
    else:
        emb = get_api_embeddings(text, model_name, **encode_kwargs)
    return emb.view(len(text), -1)