from . import SentenceEvaluator
import torch
from torch import Tensor
import logging
from tqdm import tqdm, trange
from ..util import cos_sim, dot_score
import os
import numpy as np
from typing import List, Tuple, Dict, Set, Callable
import heapq
logger = logging.getLogger(__name__)
class InformationRetrievalEvaluator(SentenceEvaluator):
"""
This class evaluates an Information Retrieval (IR) setting.
Given a set of queries and a large corpus set. It will retrieve for each query the top-k most similar document. It measures
Mean Reciprocal Rank (MRR), Recall@k, and Normalized Discounted Cumulative Gain (NDCG)
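
    A minimal usage sketch (illustrative only; the model name and the toy queries/corpus below are
    placeholders, not part of this module):

        from sentence_transformers import SentenceTransformer
        from sentence_transformers.evaluation import InformationRetrievalEvaluator

        queries = {"q1": "How many people live in Berlin?"}
        corpus = {
            "d1": "Berlin has a population of roughly 3.7 million people.",
            "d2": "Paris is the capital of France.",
        }
        relevant_docs = {"q1": {"d1"}}

        model = SentenceTransformer("all-MiniLM-L6-v2")
        ir_evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs, name="toy-ir")
        main_score = ir_evaluator(model, output_path=".")  # by default returns MAP@100 of the best score function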
"""
def __init__(self,
queries: Dict[str, str], #qid => query
corpus: Dict[str, str], #cid => doc
relevant_docs: Dict[str, Set[str]], #qid => Set[cid]
corpus_chunk_size: int = 50000,
mrr_at_k: List[int] = [10],
ndcg_at_k: List[int] = [10],
accuracy_at_k: List[int] = [1, 3, 5, 10],
precision_recall_at_k: List[int] = [1, 3, 5, 10],
map_at_k: List[int] = [100],
show_progress_bar: bool = False,
batch_size: int = 32,
name: str = '',
write_csv: bool = True,
score_functions: List[Callable[[Tensor, Tensor], Tensor] ] = {'cos_sim': cos_sim, 'dot_score': dot_score}, #Score function, higher=more similar
main_score_function: str = None
):
        self.queries_ids = []
        for qid in queries:
            if qid in relevant_docs and len(relevant_docs[qid]) > 0:
                self.queries_ids.append(qid)

        self.queries = [queries[qid] for qid in self.queries_ids]

        self.corpus_ids = list(corpus.keys())
        self.corpus = [corpus[cid] for cid in self.corpus_ids]

        self.relevant_docs = relevant_docs
        self.corpus_chunk_size = corpus_chunk_size
        self.mrr_at_k = mrr_at_k
        self.ndcg_at_k = ndcg_at_k
        self.accuracy_at_k = accuracy_at_k
        self.precision_recall_at_k = precision_recall_at_k
        self.map_at_k = map_at_k

        self.show_progress_bar = show_progress_bar
        self.batch_size = batch_size
        self.name = name
        self.write_csv = write_csv
        self.score_functions = score_functions
        self.score_function_names = sorted(list(self.score_functions.keys()))
        self.main_score_function = main_score_function

        if name:
            name = "_" + name

        self.csv_file: str = "Information-Retrieval_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps"]

        for score_name in self.score_function_names:
            for k in accuracy_at_k:
                self.csv_headers.append("{}-Accuracy@{}".format(score_name, k))

            for k in precision_recall_at_k:
                self.csv_headers.append("{}-Precision@{}".format(score_name, k))
                self.csv_headers.append("{}-Recall@{}".format(score_name, k))

            for k in mrr_at_k:
                self.csv_headers.append("{}-MRR@{}".format(score_name, k))

            for k in ndcg_at_k:
                self.csv_headers.append("{}-NDCG@{}".format(score_name, k))

            for k in map_at_k:
                self.csv_headers.append("{}-MAP@{}".format(score_name, k))
    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1, *args, **kwargs) -> float:
        if epoch != -1:
            out_txt = " after epoch {}:".format(epoch) if steps == -1 else " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logger.info("Information Retrieval Evaluation on " + self.name + " dataset" + out_txt)

        scores = self.compute_metrices(model, *args, **kwargs)

        # Write results to disk
        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            if not os.path.isfile(csv_path):
                fOut = open(csv_path, mode="w", encoding="utf-8")
                fOut.write(",".join(self.csv_headers))
                fOut.write("\n")
            else:
                fOut = open(csv_path, mode="a", encoding="utf-8")

            output_data = [epoch, steps]
            for name in self.score_function_names:
                for k in self.accuracy_at_k:
                    output_data.append(scores[name]['accuracy@k'][k])

                for k in self.precision_recall_at_k:
                    output_data.append(scores[name]['precision@k'][k])
                    output_data.append(scores[name]['recall@k'][k])

                for k in self.mrr_at_k:
                    output_data.append(scores[name]['mrr@k'][k])

                for k in self.ndcg_at_k:
                    output_data.append(scores[name]['ndcg@k'][k])

                for k in self.map_at_k:
                    output_data.append(scores[name]['map@k'][k])

            fOut.write(",".join(map(str, output_data)))
            fOut.write("\n")
            fOut.close()

        if self.main_score_function is None:
            return max([scores[name]['map@k'][max(self.map_at_k)] for name in self.score_function_names])
        else:
            return scores[self.main_score_function]['map@k'][max(self.map_at_k)]
    def compute_metrices(self, model, corpus_model=None, corpus_embeddings: Tensor = None) -> Dict[str, float]:
        if corpus_model is None:
            corpus_model = model

        max_k = max(max(self.mrr_at_k), max(self.ndcg_at_k), max(self.accuracy_at_k), max(self.precision_recall_at_k), max(self.map_at_k))

        # Compute embeddings for the queries
        query_embeddings = model.encode(self.queries, show_progress_bar=self.show_progress_bar, batch_size=self.batch_size, convert_to_tensor=True)

        queries_result_list = {}
        for name in self.score_functions:
            queries_result_list[name] = [[] for _ in range(len(query_embeddings))]

        # Iterate over chunks of the corpus
        for corpus_start_idx in trange(0, len(self.corpus), self.corpus_chunk_size, desc='Corpus Chunks', disable=not self.show_progress_bar):
            corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, len(self.corpus))

            # Encode chunk of corpus
            if corpus_embeddings is None:
                sub_corpus_embeddings = corpus_model.encode(self.corpus[corpus_start_idx:corpus_end_idx], show_progress_bar=False, batch_size=self.batch_size, convert_to_tensor=True)
            else:
                sub_corpus_embeddings = corpus_embeddings[corpus_start_idx:corpus_end_idx]

            # Compute similarities with every score function (e.g. cosine similarity, dot product)
            for name, score_function in self.score_functions.items():
                pair_scores = score_function(query_embeddings, sub_corpus_embeddings)

                # Get top-k values
                pair_scores_top_k_values, pair_scores_top_k_idx = torch.topk(pair_scores, min(max_k, len(pair_scores[0])), dim=1, largest=True, sorted=False)
                pair_scores_top_k_values = pair_scores_top_k_values.cpu().tolist()
                pair_scores_top_k_idx = pair_scores_top_k_idx.cpu().tolist()

                for query_itr in range(len(query_embeddings)):
                    for sub_corpus_id, score in zip(pair_scores_top_k_idx[query_itr], pair_scores_top_k_values[query_itr]):
                        corpus_id = self.corpus_ids[corpus_start_idx + sub_corpus_id]
                        if len(queries_result_list[name][query_itr]) < max_k:
                            # heapq orders tuples by their first element (the score), so the lowest score sits at the heap root
                            heapq.heappush(queries_result_list[name][query_itr], (score, corpus_id))
                        else:
                            heapq.heappushpop(queries_result_list[name][query_itr], (score, corpus_id))

        for name in queries_result_list:
            for query_itr in range(len(queries_result_list[name])):
                for doc_itr in range(len(queries_result_list[name][query_itr])):
                    score, corpus_id = queries_result_list[name][query_itr][doc_itr]
                    queries_result_list[name][query_itr][doc_itr] = {'corpus_id': corpus_id, 'score': score}

        logger.info("Queries: {}".format(len(self.queries)))
        logger.info("Corpus: {}\n".format(len(self.corpus)))

        # Compute scores
        scores = {name: self.compute_metrics(queries_result_list[name]) for name in self.score_functions}

        # Output
        for name in self.score_function_names:
            logger.info("Score-Function: {}".format(name))
            self.output_scores(scores[name])

        return scores
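    # The dict returned above is keyed first by score-function name and then by metric, e.g.
    # scores['cos_sim']['map@k'][100]; see compute_metrics below for the exact metric keys
    # ('accuracy@k', 'precision@k', 'recall@k', 'ndcg@k', 'mrr@k', 'map@k').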
    def compute_metrics(self, queries_result_list: List[object]):
        # Init score computation values
        num_hits_at_k = {k: 0 for k in self.accuracy_at_k}
        precisions_at_k = {k: [] for k in self.precision_recall_at_k}
        recall_at_k = {k: [] for k in self.precision_recall_at_k}
        MRR = {k: 0 for k in self.mrr_at_k}
        ndcg = {k: [] for k in self.ndcg_at_k}
        AveP_at_k = {k: [] for k in self.map_at_k}

        # Compute scores on results
        for query_itr in range(len(queries_result_list)):
            query_id = self.queries_ids[query_itr]

            # Sort scores
            top_hits = sorted(queries_result_list[query_itr], key=lambda x: x['score'], reverse=True)
            query_relevant_docs = self.relevant_docs[query_id]

            # Accuracy@k - We count the result as correct if at least one relevant doc is among the top-k documents
            for k_val in self.accuracy_at_k:
                for hit in top_hits[0:k_val]:
                    if hit['corpus_id'] in query_relevant_docs:
                        num_hits_at_k[k_val] += 1
                        break

            # Precision and Recall@k
            for k_val in self.precision_recall_at_k:
                num_correct = 0
                for hit in top_hits[0:k_val]:
                    if hit['corpus_id'] in query_relevant_docs:
                        num_correct += 1

                precisions_at_k[k_val].append(num_correct / k_val)
                recall_at_k[k_val].append(num_correct / len(query_relevant_docs))

            # MRR@k
            for k_val in self.mrr_at_k:
                for rank, hit in enumerate(top_hits[0:k_val]):
                    if hit['corpus_id'] in query_relevant_docs:
                        MRR[k_val] += 1.0 / (rank + 1)
                        break

            # NDCG@k
            for k_val in self.ndcg_at_k:
                predicted_relevance = [1 if top_hit['corpus_id'] in query_relevant_docs else 0 for top_hit in top_hits[0:k_val]]
                true_relevances = [1] * len(query_relevant_docs)

                ndcg_value = self.compute_dcg_at_k(predicted_relevance, k_val) / self.compute_dcg_at_k(true_relevances, k_val)
                ndcg[k_val].append(ndcg_value)

            # MAP@k
            for k_val in self.map_at_k:
                num_correct = 0
                sum_precisions = 0

                for rank, hit in enumerate(top_hits[0:k_val]):
                    if hit['corpus_id'] in query_relevant_docs:
                        num_correct += 1
                        sum_precisions += num_correct / (rank + 1)

                avg_precision = sum_precisions / min(k_val, len(query_relevant_docs))
                AveP_at_k[k_val].append(avg_precision)

        # Compute averages
        for k in num_hits_at_k:
            num_hits_at_k[k] /= len(self.queries)

        for k in precisions_at_k:
            precisions_at_k[k] = np.mean(precisions_at_k[k])

        for k in recall_at_k:
            recall_at_k[k] = np.mean(recall_at_k[k])

        for k in ndcg:
            ndcg[k] = np.mean(ndcg[k])

        for k in MRR:
            MRR[k] /= len(self.queries)

        for k in AveP_at_k:
            AveP_at_k[k] = np.mean(AveP_at_k[k])

        return {'accuracy@k': num_hits_at_k, 'precision@k': precisions_at_k, 'recall@k': recall_at_k, 'ndcg@k': ndcg, 'mrr@k': MRR, 'map@k': AveP_at_k}
    def output_scores(self, scores):
        for k in scores['accuracy@k']:
            logger.info("Accuracy@{}: {:.2f}%".format(k, scores['accuracy@k'][k] * 100))

        for k in scores['precision@k']:
            logger.info("Precision@{}: {:.2f}%".format(k, scores['precision@k'][k] * 100))

        for k in scores['recall@k']:
            logger.info("Recall@{}: {:.2f}%".format(k, scores['recall@k'][k] * 100))

        for k in scores['mrr@k']:
            logger.info("MRR@{}: {:.4f}".format(k, scores['mrr@k'][k]))

        for k in scores['ndcg@k']:
            logger.info("NDCG@{}: {:.4f}".format(k, scores['ndcg@k'][k]))

        for k in scores['map@k']:
            logger.info("MAP@{}: {:.4f}".format(k, scores['map@k'][k]))
    @staticmethod
    def compute_dcg_at_k(relevances, k):
        dcg = 0
        for i in range(min(len(relevances), k)):
            dcg += relevances[i] / np.log2(i + 2)  # +2 because our index starts at 0
        return dcg