from . import SentenceEvaluator
import torch
from torch import Tensor
import logging
from tqdm import tqdm, trange
from ..util import cos_sim, dot_score
import os
import numpy as np
from typing import List, Tuple, Dict, Set, Callable
import heapq


logger = logging.getLogger(__name__)


class InformationRetrievalEvaluator(SentenceEvaluator):
    """
    This class evaluates an Information Retrieval (IR) setting.

    Given a set of queries and a large corpus, it retrieves for each query the top-k most similar documents
    and measures Accuracy@k, Precision@k, Recall@k, Mean Reciprocal Rank (MRR@k),
    Normalized Discounted Cumulative Gain (NDCG@k) and Mean Average Precision (MAP@k).
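
    A minimal usage sketch (assuming ``model`` is an already loaded SentenceTransformer model and that
    ``queries``, ``corpus`` and ``relevant_docs`` are dictionaries of the shapes described above)::

        ir_evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs, name='dev')
        score = ir_evaluator(model, output_path='output_folder')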
    """

    def __init__(self,
                 queries: Dict[str, str],  # qid => query text
                 corpus: Dict[str, str],  # cid => document text
                 relevant_docs: Dict[str, Set[str]],  # qid => set of relevant cids
                 corpus_chunk_size: int = 50000,
                 mrr_at_k: List[int] = [10],
                 ndcg_at_k: List[int] = [10],
                 accuracy_at_k: List[int] = [1, 3, 5, 10],
                 precision_recall_at_k: List[int] = [1, 3, 5, 10],
                 map_at_k: List[int] = [100],
                 show_progress_bar: bool = False,
                 batch_size: int = 32,
                 name: str = '',
                 write_csv: bool = True,
                 score_functions: Dict[str, Callable[[Tensor, Tensor], Tensor]] = {'cos_sim': cos_sim, 'dot_score': dot_score},
                 main_score_function: str = None
                 ):
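        # Keep only queries that have at least one relevant document; the metrics below
        # divide by the number of relevant documents, so queries without any are skipped.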
        self.queries_ids = []
        for qid in queries:
            if qid in relevant_docs and len(relevant_docs[qid]) > 0:
                self.queries_ids.append(qid)

        self.queries = [queries[qid] for qid in self.queries_ids]

        self.corpus_ids = list(corpus.keys())
        self.corpus = [corpus[cid] for cid in self.corpus_ids]

        self.relevant_docs = relevant_docs
        self.corpus_chunk_size = corpus_chunk_size
        self.mrr_at_k = mrr_at_k
        self.ndcg_at_k = ndcg_at_k
        self.accuracy_at_k = accuracy_at_k
        self.precision_recall_at_k = precision_recall_at_k
        self.map_at_k = map_at_k

        self.show_progress_bar = show_progress_bar
        self.batch_size = batch_size
        self.name = name
        self.write_csv = write_csv
        self.score_functions = score_functions
        self.score_function_names = sorted(list(self.score_functions.keys()))
        self.main_score_function = main_score_function

        if name:
            name = "_" + name

        self.csv_file: str = "Information-Retrieval_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps"]
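
        # One CSV column per (score function, metric, k) combination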
        for score_name in self.score_function_names:
            for k in accuracy_at_k:
                self.csv_headers.append("{}-Accuracy@{}".format(score_name, k))

            for k in precision_recall_at_k:
                self.csv_headers.append("{}-Precision@{}".format(score_name, k))
                self.csv_headers.append("{}-Recall@{}".format(score_name, k))

            for k in mrr_at_k:
                self.csv_headers.append("{}-MRR@{}".format(score_name, k))

            for k in ndcg_at_k:
                self.csv_headers.append("{}-NDCG@{}".format(score_name, k))

            for k in map_at_k:
                self.csv_headers.append("{}-MAP@{}".format(score_name, k))

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1, *args, **kwargs) -> float:
        if epoch != -1:
            out_txt = " after epoch {}:".format(epoch) if steps == -1 else " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logger.info("Information Retrieval Evaluation on " + self.name + " dataset" + out_txt)

        scores = self.compute_metrices(model, *args, **kwargs)
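
        # Optionally append this evaluation run as one row to the CSV report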
        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            if not os.path.isfile(csv_path):
                fOut = open(csv_path, mode="w", encoding="utf-8")
                fOut.write(",".join(self.csv_headers))
                fOut.write("\n")
            else:
                fOut = open(csv_path, mode="a", encoding="utf-8")

            output_data = [epoch, steps]
            for name in self.score_function_names:
                for k in self.accuracy_at_k:
                    output_data.append(scores[name]['accuracy@k'][k])

                for k in self.precision_recall_at_k:
                    output_data.append(scores[name]['precision@k'][k])
                    output_data.append(scores[name]['recall@k'][k])

                for k in self.mrr_at_k:
                    output_data.append(scores[name]['mrr@k'][k])

                for k in self.ndcg_at_k:
                    output_data.append(scores[name]['ndcg@k'][k])

                for k in self.map_at_k:
                    output_data.append(scores[name]['map@k'][k])

            fOut.write(",".join(map(str, output_data)))
            fOut.write("\n")
            fOut.close()
        if self.main_score_function is None:
            return max([scores[name]['map@k'][max(self.map_at_k)] for name in self.score_function_names])
        else:
            return scores[self.main_score_function]['map@k'][max(self.map_at_k)]

    def compute_metrices(self, model, corpus_model=None, corpus_embeddings: Tensor = None) -> Dict[str, Dict[str, Dict[int, float]]]:
        if corpus_model is None:
            corpus_model = model

        max_k = max(max(self.mrr_at_k), max(self.ndcg_at_k), max(self.accuracy_at_k), max(self.precision_recall_at_k), max(self.map_at_k))

        # Encode all queries once
        query_embeddings = model.encode(self.queries, show_progress_bar=self.show_progress_bar, batch_size=self.batch_size, convert_to_tensor=True)

        queries_result_list = {}
        for name in self.score_functions:
            queries_result_list[name] = [[] for _ in range(len(query_embeddings))]
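
        # Encode the corpus in chunks to keep memory bounded; for every chunk and every score
        # function, score all queries against the chunk and keep a running top-k per query.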
        for corpus_start_idx in trange(0, len(self.corpus), self.corpus_chunk_size, desc='Corpus Chunks', disable=not self.show_progress_bar):
            corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, len(self.corpus))

            if corpus_embeddings is None:
                sub_corpus_embeddings = corpus_model.encode(self.corpus[corpus_start_idx:corpus_end_idx], show_progress_bar=False, batch_size=self.batch_size, convert_to_tensor=True)
            else:
                sub_corpus_embeddings = corpus_embeddings[corpus_start_idx:corpus_end_idx]

            for name, score_function in self.score_functions.items():
                pair_scores = score_function(query_embeddings, sub_corpus_embeddings)

                pair_scores_top_k_values, pair_scores_top_k_idx = torch.topk(pair_scores, min(max_k, len(pair_scores[0])), dim=1, largest=True, sorted=False)
                pair_scores_top_k_values = pair_scores_top_k_values.cpu().tolist()
                pair_scores_top_k_idx = pair_scores_top_k_idx.cpu().tolist()
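
                # Merge the chunk's hits into a per-query min-heap of at most max_k entries;
                # heappushpop evicts the currently lowest-scoring hit once the heap is full.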
                for query_itr in range(len(query_embeddings)):
                    for sub_corpus_id, score in zip(pair_scores_top_k_idx[query_itr], pair_scores_top_k_values[query_itr]):
                        corpus_id = self.corpus_ids[corpus_start_idx + sub_corpus_id]
                        if len(queries_result_list[name][query_itr]) < max_k:
                            heapq.heappush(queries_result_list[name][query_itr], (score, corpus_id))
                        else:
                            heapq.heappushpop(queries_result_list[name][query_itr], (score, corpus_id))

        # Convert the (score, corpus_id) heap entries into result dicts
        for name in queries_result_list:
            for query_itr in range(len(queries_result_list[name])):
                for doc_itr in range(len(queries_result_list[name][query_itr])):
                    score, corpus_id = queries_result_list[name][query_itr][doc_itr]
                    queries_result_list[name][query_itr][doc_itr] = {'corpus_id': corpus_id, 'score': score}

        logger.info("Queries: {}".format(len(self.queries)))
        logger.info("Corpus: {}\n".format(len(self.corpus)))
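
        # Compute the metrics separately for each score function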
        scores = {name: self.compute_metrics(queries_result_list[name]) for name in self.score_functions}

        for name in self.score_function_names:
            logger.info("Score-Function: {}".format(name))
            self.output_scores(scores[name])

        return scores

    def compute_metrics(self, queries_result_list: List[object]):
        num_hits_at_k = {k: 0 for k in self.accuracy_at_k}
        precisions_at_k = {k: [] for k in self.precision_recall_at_k}
        recall_at_k = {k: [] for k in self.precision_recall_at_k}
        MRR = {k: 0 for k in self.mrr_at_k}
        ndcg = {k: [] for k in self.ndcg_at_k}
        AveP_at_k = {k: [] for k in self.map_at_k}

        # Compute the metric contributions of every query
        for query_itr in range(len(queries_result_list)):
            query_id = self.queries_ids[query_itr]

            # Sort the hits of this query by score, best first
            top_hits = sorted(queries_result_list[query_itr], key=lambda x: x['score'], reverse=True)
            query_relevant_docs = self.relevant_docs[query_id]
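
            # Accuracy@k: a query counts as correct if at least one relevant document is among the top-k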
            for k_val in self.accuracy_at_k:
                for hit in top_hits[0:k_val]:
                    if hit['corpus_id'] in query_relevant_docs:
                        num_hits_at_k[k_val] += 1
                        break
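
            # Precision@k and Recall@k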
            for k_val in self.precision_recall_at_k:
                num_correct = 0
                for hit in top_hits[0:k_val]:
                    if hit['corpus_id'] in query_relevant_docs:
                        num_correct += 1

                precisions_at_k[k_val].append(num_correct / k_val)
                recall_at_k[k_val].append(num_correct / len(query_relevant_docs))
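
            # MRR@k: reciprocal rank of the first relevant document within the top-k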
            for k_val in self.mrr_at_k:
                for rank, hit in enumerate(top_hits[0:k_val]):
                    if hit['corpus_id'] in query_relevant_docs:
                        MRR[k_val] += 1.0 / (rank + 1)
                        break
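
            # NDCG@k with binary relevance; the ideal DCG assumes all relevant documents ranked first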
            for k_val in self.ndcg_at_k:
                predicted_relevance = [1 if top_hit['corpus_id'] in query_relevant_docs else 0 for top_hit in top_hits[0:k_val]]
                true_relevances = [1] * len(query_relevant_docs)

                ndcg_value = self.compute_dcg_at_k(predicted_relevance, k_val) / self.compute_dcg_at_k(true_relevances, k_val)
                ndcg[k_val].append(ndcg_value)
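
            # MAP@k: precision summed at the ranks of the relevant documents in the top-k, normalized by min(k, #relevant)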
            for k_val in self.map_at_k:
                num_correct = 0
                sum_precisions = 0

                for rank, hit in enumerate(top_hits[0:k_val]):
                    if hit['corpus_id'] in query_relevant_docs:
                        num_correct += 1
                        sum_precisions += num_correct / (rank + 1)

                avg_precision = sum_precisions / min(k_val, len(query_relevant_docs))
                AveP_at_k[k_val].append(avg_precision)
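
        # Average the per-query values over all evaluated queries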
        for k in num_hits_at_k:
            num_hits_at_k[k] /= len(self.queries)

        for k in precisions_at_k:
            precisions_at_k[k] = np.mean(precisions_at_k[k])

        for k in recall_at_k:
            recall_at_k[k] = np.mean(recall_at_k[k])

        for k in ndcg:
            ndcg[k] = np.mean(ndcg[k])

        for k in MRR:
            MRR[k] /= len(self.queries)

        for k in AveP_at_k:
            AveP_at_k[k] = np.mean(AveP_at_k[k])

        return {'accuracy@k': num_hits_at_k, 'precision@k': precisions_at_k, 'recall@k': recall_at_k, 'ndcg@k': ndcg, 'mrr@k': MRR, 'map@k': AveP_at_k}

    def output_scores(self, scores):
        for k in scores['accuracy@k']:
            logger.info("Accuracy@{}: {:.2f}%".format(k, scores['accuracy@k'][k] * 100))

        for k in scores['precision@k']:
            logger.info("Precision@{}: {:.2f}%".format(k, scores['precision@k'][k] * 100))

        for k in scores['recall@k']:
            logger.info("Recall@{}: {:.2f}%".format(k, scores['recall@k'][k] * 100))

        for k in scores['mrr@k']:
            logger.info("MRR@{}: {:.4f}".format(k, scores['mrr@k'][k]))

        for k in scores['ndcg@k']:
            logger.info("NDCG@{}: {:.4f}".format(k, scores['ndcg@k'][k]))

        for k in scores['map@k']:
            logger.info("MAP@{}: {:.4f}".format(k, scores['map@k'][k]))

    @staticmethod
    def compute_dcg_at_k(relevances, k):
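        # DCG@k = sum_{i=1}^{k} rel_i / log2(i + 1)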
        dcg = 0
        for i in range(min(len(relevances), k)):
            dcg += relevances[i] / np.log2(i + 2)
        return dcg