"""
input: query, node_type, topk
output: pred_dict: {node_id: score}
"""
import bm25s
from tqdm import tqdm
import sys
from pathlib import Path
# Get the absolute path of the current script
current_file = Path(__file__).resolve()
project_root = current_file.parents[2]
# Add the project root to the system path
sys.path.append(str(project_root))
from Reasoning.text_retrievers.stark_model import ModelForSTaRKQA
target_type = {'amazon': 'product', 'prime': 'combine', 'mag': 'paper'}
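# 'combine' is a pseudo node type: when a dataset's queries can target more
# than one node type ('prime' here), a single retriever indexes all candidate
# nodes together instead of one index per type.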

class BM25(ModelForSTaRKQA):
    def __init__(self, skb, dataset_name):
        super(BM25, self).__init__(skb)
        self.retrievers = {}
        self.text_to_ids = {}
        self.nodeid_to_index = {}
        type_names = skb.node_type_lst()
        self.target_type = target_type[dataset_name]

        # If the target type is not an actual node type (i.e. 'combine'),
        # build a single retriever over all candidate nodes.
        if self.target_type not in type_names:
            ids = skb.get_candidate_ids()
            # Apply the same '&' / 'P.O.R' replacements as the per-type
            # corpora below, so queries preprocessed in retrieve() match.
            corpus = [skb.get_doc_info(id).replace('&', '_and_').replace('P.O.R', 'P_dot_O_dot_R')
                      for id in tqdm(ids, desc="Gathering docs for combine")]
            retriever = bm25s.BM25(corpus=corpus)
            retriever.index(bm25s.tokenize(corpus))
            # Build a hash map from document text to node_id
            text_to_id = {hash(text): id for text, id in zip(corpus, ids)}
            # Store the retriever and text_to_id under the pseudo type
            self.retrievers[self.target_type] = retriever
            self.text_to_ids[self.target_type] = text_to_id
            self.nodeid_to_index[self.target_type] = {id: i for i, id in enumerate(ids)}

        # Initialize a retriever and text-to-id map for each node type
        for type_name in type_names:
            ids = skb.get_node_ids_by_type(type_name)
            # Manually replace '&' (and 'P.O.R') because the tokenizer treats
            # '&' as a special character and would not tokenize it.
            corpus = [skb.get_doc_info(id).replace('&', '_and_').replace('P.O.R', 'P_dot_O_dot_R')
                      for id in tqdm(ids, desc=f"Gathering docs for {type_name}")]
            # Create the BM25 model for the current type_name
            retriever = bm25s.BM25(corpus=corpus)
            retriever.index(bm25s.tokenize(corpus))
            # Build a hash map from document text to node_id
            text_to_id = {hash(text): id for text, id in zip(corpus, ids)}
            # Store the retriever and text_to_id by type_name
            self.retrievers[type_name] = retriever
            self.text_to_ids[type_name] = text_to_id
            # Map node_id to its index within this type's corpus
            self.nodeid_to_index[type_name] = {id: i for i, id in enumerate(ids)}
    def score(self, query, q_id, candidate_ids):
        pred_dict = {}
        # Tokenize the query once; bm25s.tokenize returns (ids, vocab), and
        # the vocab keys are the unique query tokens that get_scores expects.
        query_tokens = list(bm25s.tokenize(query)[1].keys())
        for c_id in candidate_ids:
            type_name = self.skb.get_node_type_by_id(c_id)
            # Score the whole corpus for this node type, then pick out the
            # candidate's score via its corpus index.
            scores = self.retrievers[type_name].get_scores(query_tokens)
            pred_dict[c_id] = scores[self.nodeid_to_index[type_name][c_id]]
        return pred_dict
    def retrieve(self, query, q_id, topk, node_type=None):
        """
        Compute similarity scores for the given query and return the top-k candidates.

        Args:
            query (str): Query string.
            q_id: Query id (unused here).
            topk (int): Number of candidates to return.
            node_type (str or list, optional): Node type(s) to retrieve from.
                A list with more than one type falls back to the combined
                retriever.

        Returns:
            pred_dict (dict): Candidate ids mapped to their similarity scores.
        """
        # Apply the same replacements used when indexing the corpora
        if '&' in query:
            query = query.replace('&', '_and_')
        if 'P.O.R' in query:
            query = query.replace('P.O.R', 'P_dot_O_dot_R')
        if isinstance(node_type, list):
            if len(node_type) > 1:
                node_type = 'combine'
            else:
                node_type = node_type[0]
        results, scores = self.retrievers[node_type].retrieve(bm25s.tokenize(query), k=topk)
        # results holds the retrieved document texts; map each back to its
        # node_id through the hash map built at index time
        ids = [self.text_to_ids[node_type][hash(result.item())] for result in results[0]]
        pred_dict = dict(zip(ids, scores[0].tolist()))
        return pred_dict

if __name__ == '__main__':
    print("Testing BM25")