import os import random import time import torch import torch.nn as nn from itertools import accumulate from math import ceil from colbert.utils.runs import Run from colbert.utils.utils import print_message from colbert.evaluation.metrics import Metrics from colbert.evaluation.ranking_logger import RankingLogger from colbert.modeling.inference import ModelInference from colbert.evaluation.slow import slow_rerank def evaluate(args): args.inference = ModelInference(args.colbert, amp=args.amp) qrels, queries, topK_pids = args.qrels, args.queries, args.topK_pids depth = args.depth collection = args.collection if collection is None: topK_docs = args.topK_docs def qid2passages(qid): if collection is not None: return [collection[pid] for pid in topK_pids[qid][:depth]] else: return topK_docs[qid][:depth] metrics = Metrics(mrr_depths={10, 100}, recall_depths={50, 200, 1000}, success_depths={5, 10, 20, 50, 100, 1000}, total_queries=len(queries)) ranking_logger = RankingLogger(Run.path, qrels=qrels) args.milliseconds = [] with ranking_logger.context('ranking.tsv', also_save_annotations=(qrels is not None)) as rlogger: with torch.no_grad(): keys = sorted(list(queries.keys())) random.shuffle(keys) for query_idx, qid in enumerate(keys): query = queries[qid] print_message(query_idx, qid, query, '\n') if qrels and args.shortcircuit and len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) == 0: continue ranking = slow_rerank(args, query, topK_pids[qid], qid2passages(qid)) rlogger.log(qid, ranking, [0, 1]) if qrels: metrics.add(query_idx, qid, ranking, qrels[qid]) for i, (score, pid, passage) in enumerate(ranking): if pid in qrels[qid]: print("\n#> Found", pid, "at position", i+1, "with score", score) print(passage) break metrics.print_metrics(query_idx) metrics.log(query_idx) print_message("#> checkpoint['batch'] =", args.checkpoint['batch'], '\n') print("rlogger.filename =", rlogger.filename) if len(args.milliseconds) > 1: print('Slow-Ranking Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:])) print("\n\n") print("\n\n") # print('Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:])) print("\n\n") print('\n\n') if qrels: assert query_idx + 1 == len(keys) == len(set(keys)) metrics.output_final_metrics(os.path.join(Run.path, 'ranking.metrics'), query_idx, len(queries)) print('\n\n')