import os
import ujson
import torch
import random
from collections import defaultdict, OrderedDict
from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint
from colbert.evaluation.load_model import load_model
from colbert.utils.runs import Run
def load_queries(queries_path):
    """Load a tab-separated queries file into an ordered ``{qid: query}`` map.

    Every line must contain at least two tab-separated fields: an integer
    query ID followed by the query text. Any additional fields are ignored.

    Args:
        queries_path: Path of the TSV file to read.

    Returns:
        An ``OrderedDict`` mapping each integer QID to its query string, in
        the order the queries appear in the file.

    Raises:
        AssertionError: If a QID appears more than once.
        ValueError: If a line has fewer than two fields or its first field
            is not an integer.
    """
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    # Decode as UTF-8 explicitly (as load_qrels does) so parsing does not
    # depend on the platform's locale encoding.
    with open(queries_path, encoding="utf-8") as f:
        for line in f:
            qid, query, *_ = line.strip().split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries
def load_qrels(qrels_path):
    """Load relevance judgements into an ordered ``{qid: [pid, ...]}`` map.

    Every line must contain exactly four tab-separated integers
    ``qid, x, pid, y`` where ``x`` is 0 and ``y`` is 1.

    Args:
        qrels_path: Path of the qrels TSV file, or ``None``.

    Returns:
        ``None`` when ``qrels_path`` is ``None``; otherwise an ``OrderedDict``
        mapping each QID to the list of its positive PIDs in file order.

    Raises:
        AssertionError: If ``x != 0`` or ``y != 1`` on some line, or if a PID
            is listed more than once for the same QID.
        ValueError: If a line does not hold exactly four integers.
        ZeroDivisionError: If the file contains no judgements.
    """
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            qid, x, pid, y = map(int, line.strip().split('\t'))
            assert x == 0 and y == 1
            # Record the positive passage; without this every query would be
            # left with an empty list and the statistics below would be zero.
            qrels.setdefault(qid, []).append(pid)

    # No passage may be judged positive twice for the same query.
    assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels)

    avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2)

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels
def load_topK(topK_path):
    """Load a top-k file with query and passage text, grouped by query.

    Every line must contain exactly four tab-separated fields:
    ``qid, pid, query, passage``. The line is split without stripping, so the
    passage field keeps any trailing newline.

    Args:
        topK_path: Path of the TSV file to read.

    Returns:
        A tuple ``(queries, topK_docs, topK_pids)`` of ``OrderedDict`` objects
        keyed by integer QID: the query text, the list of passages, and the
        list of integer PIDs (the two lists are aligned, in file order).

    Raises:
        AssertionError: If a QID appears with two different query strings, or
            a PID is repeated for the same QID.
        ValueError: If a line does not hold exactly four fields, the IDs are
            not integers, or the file is empty (``max`` of no values).
    """
    queries = OrderedDict()
    topK_docs = OrderedDict()
    topK_pids = OrderedDict()

    print_message("#> Loading the top-k per query from", topK_path, "...")

    # Decode as UTF-8 explicitly (as load_qrels does) rather than relying on
    # the platform's locale encoding.
    with open(topK_path, encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            # Progress marker every 10M lines (never for the very first line).
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, query, passage = line.split('\t')
            qid, pid = int(qid), int(pid)

            assert (qid not in queries) or (queries[qid] == query)
            queries[qid] = query

            # Store the candidate; without these appends both maps would
            # only ever contain empty lists.
            topK_docs.setdefault(qid, []).append(passage)
            topK_pids.setdefault(qid, []).append(pid)

    # A passage may be listed at most once per query.
    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")

    return queries, topK_docs, topK_pids
def load_topK_pids(topK_path, qrels):
    """Load the candidate PIDs per query, plus optional inline annotations.

    Every line must start with the tab-separated integers ``qid, pid``
    followed by one to three more fields. When more than one extra field is
    present, the last one is read as an integer relevance label (0 or 1) and
    PIDs labelled 1 are collected as positives for that query.

    Args:
        topK_path: Path of the TSV file to read.
        qrels: Externally loaded positives (or ``None``). Must be ``None``
            when the top-k file itself carries positive labels.

    Returns:
        A tuple ``(topK_pids, topK_positives)``. ``topK_pids`` maps each QID
        to its list of PIDs in file order. ``topK_positives`` maps each QID to
        its positives (a set, or an empty list for queries without any) when
        the file is annotated; otherwise it is ``qrels`` as passed in.

    Raises:
        AssertionError: On a malformed field count or label, a PID repeated
            for the same QID, or when both ``qrels`` and inline annotations
            are supplied.
        ValueError: If the IDs or the label are not integers, or the file is
            empty (``max`` of no values).
    """
    topK_pids = defaultdict(list)
    topK_positives = defaultdict(list)

    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    # Decode as UTF-8 explicitly (as load_qrels does) rather than relying on
    # the platform's locale encoding.
    with open(topK_path, encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            # Progress marker every 10M lines (never for the very first line).
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, *rest = line.strip().split('\t')
            qid, pid = int(qid), int(pid)

            # Record the candidate; otherwise topK_pids stays empty and the
            # statistics below fail on max() of an empty list.
            topK_pids[qid].append(pid)

            assert len(rest) in [1, 2, 3]

            if len(rest) > 1:
                *_, label = rest
                label = int(label)
                assert label in [0, 1]

                if label >= 1:
                    topK_positives[qid].append(pid)

    # A passage may be listed (and labelled positive) at most once per query.
    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
    assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives)

    # Convert the positives to sets for fast membership tests.
    topK_positives = {qid: set(topK_positives[qid]) for qid in topK_positives}

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")

    if len(topK_positives) == 0:
        # No inline labels were found: signal "not annotated" to the caller.
        topK_positives = None
    else:
        # The statistics only make sense for an annotated file; running them
        # on None would raise a TypeError.
        assert len(topK_pids) >= len(topK_positives)

        # Give every query an entry, even those without a single positive.
        for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())):
            topK_positives[qid] = []

        assert len(topK_pids) == len(topK_positives)

        avg_positive = round(sum(len(topK_positives[qid]) for qid in topK_positives) / len(topK_pids), 2)

        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")

    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"

    if topK_positives is None:
        topK_positives = qrels

    return topK_pids, topK_positives
def load_collection(collection_path):
    """Load a tab-separated passage collection into a list.

    Every line must contain at least ``pid`` and ``passage``. The ``pid``
    must equal the zero-based line number, or be the literal ``'id'``. If a
    third field is present it is treated as a title and the stored text
    becomes ``title + ' | ' + passage``.

    Args:
        collection_path: Path of the TSV file to read.

    Returns:
        A list with one passage string per line of the file, in file order.

    Raises:
        AssertionError: If a PID does not match its line number.
        ValueError: If a line has fewer than two fields or a PID that is
            neither ``'id'`` nor an integer.
    """
    print_message("#> Loading collection...")

    collection = []

    # Decode as UTF-8 explicitly (as load_qrels does) rather than relying on
    # the platform's locale encoding.
    with open(collection_path, encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            # Progress marker every 1M lines, starting with "0M".
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip().split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if len(rest) >= 1:
                title = rest[0]
                passage = title + ' | ' + passage

            # Store the passage; without this the function always returned
            # an empty list.
            collection.append(passage)

    # Terminate the progress line, which is printed with end=' '.
    print()

    return collection
def load_colbert(args, do_print=True):
    """Load a model via ``load_model`` and cross-check key hyperparameters.

    For each of ``query_maxlen``, ``doc_maxlen``, ``dim``, ``similarity`` and
    ``amp``, a warning is issued through ``Run.warn`` when the value stored
    under ``checkpoint['arguments']`` differs from the attribute of the same
    name on ``args``. The stored arguments are then dumped as JSON when
    ``args.rank < 1``.

    Args:
        args: Namespace-like object forwarded to ``load_model``; must expose
            ``rank`` whenever the checkpoint carries an ``'arguments'`` entry.
        do_print: Forwarded to ``load_model``; also controls the trailing
            blank output below.

    Returns:
        The ``(colbert, checkpoint)`` pair returned by ``load_model``.
    """
    colbert, checkpoint = load_model(args, do_print)

    checkpoint_args = checkpoint['arguments'] if 'arguments' in checkpoint else None

    if checkpoint_args is not None:
        for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
            if not hasattr(args, k) or k not in checkpoint_args:
                continue

            a, b = checkpoint_args[k], getattr(args, k)
            if a != b:
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

        if args.rank < 1:
            print(ujson.dumps(checkpoint_args, indent=4))

    if do_print:
        # NOTE(review): the body of this branch is not visible in the reviewed
        # source; a blank separator is emitted here -- confirm the intent.
        print('\n')

    # Return unconditionally: callers unpack the pair whatever do_print is.
    return colbert, checkpoint