diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..d4e305aa7672be93b2ae387007367e89dc053f2b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.dnn filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9e5d27a98dd99b362c8420c723276cf5303e1beb --- /dev/null +++ b/README.md @@ -0,0 +1,159 @@ +# ColBERT + +### ColBERT is a _fast_ and _accurate_ retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds. + + +

+<!-- Figure 1 image omitted -->
+
+Figure 1: ColBERT's late interaction, efficiently scoring the fine-grained similarity between a query and a passage.
+
+
+As Figure 1 illustrates, ColBERT relies on fine-grained **contextual late interaction**: it encodes each passage into a **matrix** of token-level embeddings (shown above in blue). Then at search time, it embeds every query into another matrix (shown in green) and efficiently finds passages that contextually match the query using scalable vector-similarity (`MaxSim`) operators.
+
+These rich interactions allow ColBERT to surpass the quality of _single-vector_ representation models, while scaling efficiently to large corpora. You can read more in our papers:
+
+* [**ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT**](https://arxiv.org/abs/2004.12832) (SIGIR'20).
+* [**Relevance-guided Supervision for OpenQA with ColBERT**](https://arxiv.org/abs/2007.00814) (TACL'21; to appear).
+
+
+----
+
+## Installation
+
+ColBERT (currently: [v0.2.0](#releases)) requires Python 3.7+ and PyTorch 1.6+ and uses the [HuggingFace Transformers](https://github.com/huggingface/transformers) library.
+
+We strongly recommend creating a conda environment using:
+
+```
+conda env create -f conda_env.yml
+conda activate colbert-v0.2
+```
+
+If you face any problems, please [open a new issue](https://github.com/stanford-futuredata/ColBERT/issues) and we'll help you promptly!
+
+
+## Overview
+
+Using ColBERT on a dataset typically involves the following steps.
+
+**Step 0: Preprocess your collection.** At its simplest, ColBERT works with tab-separated (TSV) files: a file (e.g., `collection.tsv`) will contain all passages and another (e.g., `queries.tsv`) will contain a set of queries for searching the collection.
+
+**Step 1: Train a ColBERT model.** You can [train your own ColBERT model](#training) and [validate performance](#validation) on a suitable development set.
+
+**Step 2: Index your collection.** Once you're happy with your ColBERT model, you need to [index your collection](#indexing) to permit fast retrieval. This step encodes all passages into matrices, stores them on disk, and builds data structures for efficient search.
+
+**Step 3: Search the collection with your queries.** Given your model and index, you can [issue queries over the collection](#retrieval) to retrieve the top-k passages for each query.
+
+Below, we illustrate these steps via an example run on the MS MARCO Passage Ranking task.
+
+
+## Data
+
+This repository works directly with a simple **tab-separated file** format to store queries, passages, and top-k ranked lists.
+
+* Queries: each line is `qid \t query text`.
+* Collection: each line is `pid \t passage text`.
+* Top-k Ranking: each line is `qid \t pid \t rank`.
+
+This works directly with the data format of the [MS MARCO Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) dataset. You will need the training triples (`triples.train.small.tar.gz`), the official top-1000 ranked lists for the dev set queries (`top1000.dev`), and the dev set relevant passages (`qrels.dev.small.tsv`). For indexing the full collection, you will also need the list of passages (`collection.tar.gz`).
+
+
+## Training
+
+Training requires a list of `<query, positive passage, negative passage>` tab-separated triples.
+
+You can supply **full-text** triples, where each line is `query text \t positive passage text \t negative passage text`. Alternatively, you can supply the query and passage **IDs** as a JSONL file `[qid, pid+, pid-]` per line, in which case you should specify `--collection path/to/collection.tsv` and `--queries path/to/queries.train.tsv`.
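+
+For concreteness, here is a minimal sketch (not part of this repository; `read_triple` is a hypothetical helper invented for illustration) of how one line of each triples format can be parsed:
+
+```python
+import ujson
+
+def read_triple(line, use_ids=False):
+    # Hypothetical helper for illustration only.
+    if use_ids:
+        qid, pos_pid, neg_pid = ujson.loads(line)  # JSONL line: [qid, pid+, pid-]
+        return qid, pos_pid, neg_pid
+    # Full-text TSV line: query text \t positive passage text \t negative passage text
+    query, pos_text, neg_text = line.rstrip('\n').split('\t')
+    return query, pos_text, neg_text
+```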
+
+Example command:
+
+```
+CUDA_VISIBLE_DEVICES="0,1,2,3" \
+python -m torch.distributed.launch --nproc_per_node=4 -m \
+colbert.train --amp --doc_maxlen 180 --mask-punctuation --bsize 32 --accum 1 \
+--triples /path/to/MSMARCO/triples.train.small.tsv \
+--root /root/to/experiments/ --experiment MSMARCO-psg --similarity l2 --run msmarco.psg.l2
+```
+
+You can use one or more GPUs by modifying `CUDA_VISIBLE_DEVICES` and `--nproc_per_node`.
+
+
+## Validation
+
+Before indexing into ColBERT, you can compare a few checkpoints by re-ranking a top-k set of documents per query. This will use ColBERT _on-the-fly_: it will compute document representations _during_ query evaluation.
+
+This script requires the top-k list per query, provided as a tab-separated file whose every line contains a tuple `queryID \t passageID \t rank`, where rank is {1, 2, 3, ...} for each query. The script also accepts the format of MS MARCO's `top1000.dev` and `top1000.eval`, and you can optionally supply relevance judgements (qrels) for evaluation. This is a tab-separated file whose every line has a quadruple `qid \t 0 \t pid \t 1`, like `qrels.dev.small.tsv`.
+
+Example command:
+
+```
+python -m colbert.test --amp --doc_maxlen 180 --mask-punctuation \
+--collection /path/to/MSMARCO/collection.tsv \
+--queries /path/to/MSMARCO/queries.dev.small.tsv \
+--topk /path/to/MSMARCO/top1000.dev \
+--checkpoint /root/to/experiments/MSMARCO-psg/train.py/msmarco.psg.l2/checkpoints/colbert-200000.dnn \
+--root /root/to/experiments/ --experiment MSMARCO-psg [--qrels path/to/qrels.dev.small.tsv]
+```
+
+
+## Indexing
+
+For fast retrieval, indexing precomputes the ColBERT representations of passages.
+
+Example command:
+
+```
+CUDA_VISIBLE_DEVICES="0,1,2,3" OMP_NUM_THREADS=6 \
+python -m torch.distributed.launch --nproc_per_node=4 -m \
+colbert.index --amp --doc_maxlen 180 --mask-punctuation --bsize 256 \
+--checkpoint /root/to/experiments/MSMARCO-psg/train.py/msmarco.psg.l2/checkpoints/colbert-200000.dnn \
+--collection /path/to/MSMARCO/collection.tsv \
+--index_root /root/to/indexes/ --index_name MSMARCO.L2.32x200k \
+--root /root/to/experiments/ --experiment MSMARCO-psg
+```
+
+The index created here allows you to re-rank the top-k passages retrieved by another method (e.g., BM25).
+
+We typically recommend that you use ColBERT for **end-to-end** retrieval, where it directly finds its top-k passages from the full collection. For this, you need FAISS indexing.
+
+
+#### FAISS Indexing for end-to-end retrieval
+
+For end-to-end retrieval, you should index the document representations into [FAISS](https://github.com/facebookresearch/faiss).
+
+```
+python -m colbert.index_faiss \
+--index_root /root/to/indexes/ --index_name MSMARCO.L2.32x200k \
+--partitions 32768 --sample 0.3 \
+--root /root/to/experiments/ --experiment MSMARCO-psg
+```
+
+
+## Retrieval
+
+In the simplest case, you want to retrieve from the full collection:
+
+```
+python -m colbert.retrieve \
+--amp --doc_maxlen 180 --mask-punctuation --bsize 256 \
+--queries /path/to/MSMARCO/queries.dev.small.tsv \
+--nprobe 32 --partitions 32768 --faiss_depth 1024 \
+--index_root /root/to/indexes/ --index_name MSMARCO.L2.32x200k \
+--checkpoint /root/to/experiments/MSMARCO-psg/train.py/msmarco.psg.l2/checkpoints/colbert-200000.dnn \
+--root /root/to/experiments/ --experiment MSMARCO-psg
+```
+
+You may also want to re-rank a top-k set that you've retrieved before with ColBERT or with another model. For this, use `colbert.rerank` similarly and additionally pass `--topk`.
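+
+As a sketch, a re-ranking invocation might mirror the retrieval flags above; this assumes `colbert.rerank` accepts the same model and index arguments, so check its argument parser for the exact set:
+
+```
+python -m colbert.rerank \
+--amp --doc_maxlen 180 --mask-punctuation --bsize 256 \
+--queries /path/to/MSMARCO/queries.dev.small.tsv \
+--topk /path/to/MSMARCO/top1000.dev \
+--index_root /root/to/indexes/ --index_name MSMARCO.L2.32x200k \
+--checkpoint /root/to/experiments/MSMARCO-psg/train.py/msmarco.psg.l2/checkpoints/colbert-200000.dnn \
+--root /root/to/experiments/ --experiment MSMARCO-psg
+```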
+
+If you have a large set of queries (or want to reduce memory usage), use **batch-mode** retrieval and/or re-ranking. This can be done by passing `--batch --only_retrieval` to `colbert.retrieve` and passing `--batch --log-scores` to `colbert.rerank` alongside `--topk` with the `unordered.tsv` output of this retrieval run.
+
+Some use cases (e.g., building a user-facing search engine) require more control over retrieval. For those, you typically don't want to use the command line for retrieval. Instead, you want to import our retrieval API from Python and directly work with that (e.g., to build a simple REST API). Instructions for this are coming soon, but you will just need to adapt/modify the retrieval loop in [`colbert/ranking/retrieval.py#L33`](colbert/ranking/retrieval.py#L33).
+
+
+## Releases
+
+* v0.2.0: Sep 2020
+* v0.1.0: June 2020
+
diff --git a/colbert/__init__.py b/colbert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/colbert/__pycache__/__init__.cpython-37.pyc b/colbert/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4eb90ef5d5a6c2254af7ac1c7375184a0b1f019
Binary files /dev/null and b/colbert/__pycache__/__init__.cpython-37.pyc differ
diff --git a/colbert/__pycache__/index.cpython-37.pyc b/colbert/__pycache__/index.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b13bba8cbfa3e60acc6567edfb8eb7413092200f
Binary files /dev/null and b/colbert/__pycache__/index.cpython-37.pyc differ
diff --git a/colbert/__pycache__/index_faiss.cpython-37.pyc b/colbert/__pycache__/index_faiss.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2aaed04ba1a43eaed3b93c6cbfd696c9a869bbd3
Binary files /dev/null and b/colbert/__pycache__/index_faiss.cpython-37.pyc differ
diff --git a/colbert/__pycache__/parameters.cpython-37.pyc b/colbert/__pycache__/parameters.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eead41d1892b3944e741f85860ed8eb629838675
Binary files /dev/null and b/colbert/__pycache__/parameters.cpython-37.pyc differ
diff --git a/colbert/__pycache__/retrieve.cpython-37.pyc b/colbert/__pycache__/retrieve.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..efba785300c92482f7907c500c0df73adca959da
Binary files /dev/null and b/colbert/__pycache__/retrieve.cpython-37.pyc differ
diff --git a/colbert/__pycache__/train.cpython-37.pyc b/colbert/__pycache__/train.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..638ad60a79f2fbe78216047d6f71c2338ebae585
Binary files /dev/null and b/colbert/__pycache__/train.cpython-37.pyc differ
diff --git a/colbert/evaluation/__init__.py b/colbert/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/colbert/evaluation/__pycache__/__init__.cpython-37.pyc b/colbert/evaluation/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0919c75c276fe3ed683d2b0ae14b9ffce7df1de9
Binary files /dev/null and b/colbert/evaluation/__pycache__/__init__.cpython-37.pyc differ
diff --git a/colbert/evaluation/__pycache__/load_model.cpython-37.pyc b/colbert/evaluation/__pycache__/load_model.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4af0c30ed903bbb3607d687c317d4330f56f4728
Binary files /dev/null and
b/colbert/evaluation/__pycache__/load_model.cpython-37.pyc differ diff --git a/colbert/evaluation/__pycache__/loaders.cpython-37.pyc b/colbert/evaluation/__pycache__/loaders.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c14966be8fca94c6750ef2b607bcc105a8e44de3 Binary files /dev/null and b/colbert/evaluation/__pycache__/loaders.cpython-37.pyc differ diff --git a/colbert/evaluation/__pycache__/ranking_logger.cpython-37.pyc b/colbert/evaluation/__pycache__/ranking_logger.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66296a6549c905407e74b55e2c9c11e8b4017a68 Binary files /dev/null and b/colbert/evaluation/__pycache__/ranking_logger.cpython-37.pyc differ diff --git a/colbert/evaluation/load_model.py b/colbert/evaluation/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..87410e991d5d96890f1106944f6fbe6de98899a8 --- /dev/null +++ b/colbert/evaluation/load_model.py @@ -0,0 +1,28 @@ +import os +import ujson +import torch +import random + +from collections import defaultdict, OrderedDict + +from colbert.parameters import DEVICE +from colbert.modeling.colbert import ColBERT +from colbert.utils.utils import print_message, load_checkpoint + + +def load_model(args, do_print=True): + colbert = ColBERT.from_pretrained('bert-base-multilingual-uncased', + query_maxlen=args.query_maxlen, + doc_maxlen=args.doc_maxlen, + dim=args.dim, + similarity_metric=args.similarity, + mask_punctuation=args.mask_punctuation) + colbert = colbert.to(DEVICE) + + print_message("#> Loading model checkpoint.", condition=do_print) + + checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print) + + colbert.eval() + + return colbert, checkpoint diff --git a/colbert/evaluation/loaders.py b/colbert/evaluation/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..09252f491a76294857284e12ea060789eb189a3a --- /dev/null +++ b/colbert/evaluation/loaders.py @@ -0,0 +1,196 @@ +import os +import ujson +import torch +import random + +from collections import defaultdict, OrderedDict + +from colbert.parameters import DEVICE +from colbert.modeling.colbert import ColBERT +from colbert.utils.utils import print_message, load_checkpoint +from colbert.evaluation.load_model import load_model +from colbert.utils.runs import Run + + +def load_queries(queries_path): + queries = OrderedDict() + + print_message("#> Loading the queries from", queries_path, "...") + + with open(queries_path) as f: + for line in f: + qid, query, *_ = line.strip().split('\t') + qid = int(qid) + + assert (qid not in queries), ("Query QID", qid, "is repeated!") + queries[qid] = query + + print_message("#> Got", len(queries), "queries. 
All QIDs are unique.\n") + + return queries + + +def load_qrels(qrels_path): + if qrels_path is None: + return None + + print_message("#> Loading qrels from", qrels_path, "...") + + qrels = OrderedDict() + with open(qrels_path, mode='r', encoding="utf-8") as f: + for line in f: + qid, x, pid, y = map(int, line.strip().split('\t')) + assert x == 0 and y == 1 + qrels[qid] = qrels.get(qid, []) + qrels[qid].append(pid) + + assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels) + + avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2) + + print_message("#> Loaded qrels for", len(qrels), "unique queries with", + avg_positive, "positives per query on average.\n") + + return qrels + + +def load_topK(topK_path): + queries = OrderedDict() + topK_docs = OrderedDict() + topK_pids = OrderedDict() + + print_message("#> Loading the top-k per query from", topK_path, "...") + + with open(topK_path) as f: + for line_idx, line in enumerate(f): + if line_idx and line_idx % (10*1000*1000) == 0: + print(line_idx, end=' ', flush=True) + + qid, pid, query, passage = line.split('\t') + qid, pid = int(qid), int(pid) + + assert (qid not in queries) or (queries[qid] == query) + queries[qid] = query + topK_docs[qid] = topK_docs.get(qid, []) + topK_docs[qid].append(passage) + topK_pids[qid] = topK_pids.get(qid, []) + topK_pids[qid].append(pid) + + print() + + assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids) + + Ks = [len(topK_pids[qid]) for qid in topK_pids] + + print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2)) + print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n") + + return queries, topK_docs, topK_pids + + +def load_topK_pids(topK_path, qrels): + topK_pids = defaultdict(list) + topK_positives = defaultdict(list) + + print_message("#> Loading the top-k PIDs per query from", topK_path, "...") + + with open(topK_path) as f: + for line_idx, line in enumerate(f): + if line_idx and line_idx % (10*1000*1000) == 0: + print(line_idx, end=' ', flush=True) + + qid, pid, *rest = line.strip().split('\t') + qid, pid = int(qid), int(pid) + + topK_pids[qid].append(pid) + + assert len(rest) in [1, 2, 3] + + if len(rest) > 1: + *_, label = rest + label = int(label) + assert label in [0, 1] + + if label >= 1: + topK_positives[qid].append(pid) + + print() + + assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids) + assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives) + + # Make them sets for fast lookups later + topK_positives = {qid: set(topK_positives[qid]) for qid in topK_positives} + + Ks = [len(topK_pids[qid]) for qid in topK_pids] + + print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2)) + print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n") + + if len(topK_positives) == 0: + topK_positives = None + else: + assert len(topK_pids) >= len(topK_positives) + + for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())): + topK_positives[qid] = [] + + assert len(topK_pids) == len(topK_positives) + + avg_positive = round(sum(len(topK_positives[qid]) for qid in topK_positives) / len(topK_pids), 2) + + print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with", + avg_positive, "positives per query on average.\n") + + assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!" 
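+    # Exactly one source of relevance labels is allowed: either the qrels file or
+    # labels embedded in the top-K file; fall back to qrels when the latter is absent.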
+ + if topK_positives is None: + topK_positives = qrels + + return topK_pids, topK_positives + + +def load_collection(collection_path): + print_message("#> Loading collection...") + + collection = [] + + with open(collection_path) as f: + for line_idx, line in enumerate(f): + if line_idx % (1000*1000) == 0: + print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True) + + pid, passage, *rest = line.strip().split('\t') + assert pid == 'id' or int(pid) == line_idx + + if len(rest) >= 1: + title = rest[0] + passage = title + ' | ' + passage + + collection.append(passage) + + print() + + return collection + + +def load_colbert(args, do_print=True): + colbert, checkpoint = load_model(args, do_print) + + # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used. + # I.e., not their purely (i.e., training) default values. + + for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']: + if 'arguments' in checkpoint and hasattr(args, k): + if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k): + a, b = checkpoint['arguments'][k], getattr(args, k) + Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})") + + if 'arguments' in checkpoint: + if args.rank < 1: + print(ujson.dumps(checkpoint['arguments'], indent=4)) + + if do_print: + print('\n') + + return colbert, checkpoint diff --git a/colbert/evaluation/metrics.py b/colbert/evaluation/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..f7582626fd512f9400032ef61ac7d79ece6f49c3 --- /dev/null +++ b/colbert/evaluation/metrics.py @@ -0,0 +1,114 @@ +import ujson + +from collections import defaultdict +from colbert.utils.runs import Run + + +class Metrics: + def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None): + self.results = {} + self.mrr_sums = {depth: 0.0 for depth in mrr_depths} + self.recall_sums = {depth: 0.0 for depth in recall_depths} + self.success_sums = {depth: 0.0 for depth in success_depths} + self.total_queries = total_queries + + self.max_query_idx = -1 + self.num_queries_added = 0 + + def add(self, query_idx, query_key, ranking, gold_positives): + self.num_queries_added += 1 + + assert query_key not in self.results + assert len(self.results) <= query_idx + assert len(set(gold_positives)) == len(gold_positives) + assert len(set([pid for _, pid, _ in ranking])) == len(ranking) + + self.results[query_key] = ranking + + positives = [i for i, (_, pid, _) in enumerate(ranking) if pid in gold_positives] + + if len(positives) == 0: + return + + for depth in self.mrr_sums: + first_positive = positives[0] + self.mrr_sums[depth] += (1.0 / (first_positive+1.0)) if first_positive < depth else 0.0 + + for depth in self.success_sums: + first_positive = positives[0] + self.success_sums[depth] += 1.0 if first_positive < depth else 0.0 + + for depth in self.recall_sums: + num_positives_up_to_depth = len([pos for pos in positives if pos < depth]) + self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives) + + def print_metrics(self, query_idx): + for depth in sorted(self.mrr_sums): + print("MRR@" + str(depth), "=", self.mrr_sums[depth] / (query_idx+1.0)) + + for depth in sorted(self.success_sums): + print("Success@" + str(depth), "=", self.success_sums[depth] / (query_idx+1.0)) + + for depth in sorted(self.recall_sums): + print("Recall@" + str(depth), "=", self.recall_sums[depth] / (query_idx+1.0)) + + def log(self, query_idx): + assert query_idx >= 
self.max_query_idx + self.max_query_idx = query_idx + + Run.log_metric("ranking/max_query_idx", query_idx, query_idx) + Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx) + + for depth in sorted(self.mrr_sums): + score = self.mrr_sums[depth] / (query_idx+1.0) + Run.log_metric("ranking/MRR." + str(depth), score, query_idx) + + for depth in sorted(self.success_sums): + score = self.success_sums[depth] / (query_idx+1.0) + Run.log_metric("ranking/Success." + str(depth), score, query_idx) + + for depth in sorted(self.recall_sums): + score = self.recall_sums[depth] / (query_idx+1.0) + Run.log_metric("ranking/Recall." + str(depth), score, query_idx) + + def output_final_metrics(self, path, query_idx, num_queries): + assert query_idx + 1 == num_queries + assert num_queries == self.total_queries + + if self.max_query_idx < query_idx: + self.log(query_idx) + + self.print_metrics(query_idx) + + output = defaultdict(dict) + + for depth in sorted(self.mrr_sums): + score = self.mrr_sums[depth] / (query_idx+1.0) + output['mrr'][depth] = score + + for depth in sorted(self.success_sums): + score = self.success_sums[depth] / (query_idx+1.0) + output['success'][depth] = score + + for depth in sorted(self.recall_sums): + score = self.recall_sums[depth] / (query_idx+1.0) + output['recall'][depth] = score + + with open(path, 'w') as f: + ujson.dump(output, f, indent=4) + f.write('\n') + + +def evaluate_recall(qrels, queries, topK_pids): + if qrels is None: + return + + assert set(qrels.keys()) == set(queries.keys()) + recall_at_k = [len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) / max(1.0, len(qrels[qid])) + for qid in qrels] + recall_at_k = sum(recall_at_k) / len(qrels) + recall_at_k = round(recall_at_k, 3) + print("Recall @ maximum depth =", recall_at_k) + + +# TODO: If implicit qrels are used (for re-ranking), warn if a recall metric is requested + add an asterisk to output. 
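To make the bookkeeping above concrete, here is a small usage sketch of `Metrics` (the data is invented for illustration; only `print_metrics` is called, so no `Run` context is needed):

```python
from colbert.evaluation.metrics import Metrics

# One toy query whose first relevant passage (pid 42) appears at rank 2.
metrics = Metrics(mrr_depths={10}, recall_depths={10}, success_depths={10}, total_queries=1)
ranking = [(9.1, 71, 'passage A'), (8.7, 42, 'passage B'), (8.2, 13, 'passage C')]
metrics.add(0, 0, ranking, gold_positives=[42])
metrics.print_metrics(query_idx=0)  # MRR@10 = 0.5, Success@10 = 1.0, Recall@10 = 1.0
```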
diff --git a/colbert/evaluation/ranking.py b/colbert/evaluation/ranking.py new file mode 100644 index 0000000000000000000000000000000000000000..f57a42f7b2092b0dcca35414c07f97c262cf1956 --- /dev/null +++ b/colbert/evaluation/ranking.py @@ -0,0 +1,88 @@ +import os +import random +import time +import torch +import torch.nn as nn + +from itertools import accumulate +from math import ceil + +from colbert.utils.runs import Run +from colbert.utils.utils import print_message + +from colbert.evaluation.metrics import Metrics +from colbert.evaluation.ranking_logger import RankingLogger +from colbert.modeling.inference import ModelInference + +from colbert.evaluation.slow import slow_rerank + + +def evaluate(args): + args.inference = ModelInference(args.colbert, amp=args.amp) + qrels, queries, topK_pids = args.qrels, args.queries, args.topK_pids + + depth = args.depth + collection = args.collection + if collection is None: + topK_docs = args.topK_docs + + def qid2passages(qid): + if collection is not None: + return [collection[pid] for pid in topK_pids[qid][:depth]] + else: + return topK_docs[qid][:depth] + + metrics = Metrics(mrr_depths={10, 100}, recall_depths={50, 200, 1000}, + success_depths={5, 10, 20, 50, 100, 1000}, + total_queries=len(queries)) + + ranking_logger = RankingLogger(Run.path, qrels=qrels) + + args.milliseconds = [] + + with ranking_logger.context('ranking.tsv', also_save_annotations=(qrels is not None)) as rlogger: + with torch.no_grad(): + keys = sorted(list(queries.keys())) + random.shuffle(keys) + + for query_idx, qid in enumerate(keys): + query = queries[qid] + + print_message(query_idx, qid, query, '\n') + + if qrels and args.shortcircuit and len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) == 0: + continue + + ranking = slow_rerank(args, query, topK_pids[qid], qid2passages(qid)) + + rlogger.log(qid, ranking, [0, 1]) + + if qrels: + metrics.add(query_idx, qid, ranking, qrels[qid]) + + for i, (score, pid, passage) in enumerate(ranking): + if pid in qrels[qid]: + print("\n#> Found", pid, "at position", i+1, "with score", score) + print(passage) + break + + metrics.print_metrics(query_idx) + metrics.log(query_idx) + + print_message("#> checkpoint['batch'] =", args.checkpoint['batch'], '\n') + print("rlogger.filename =", rlogger.filename) + + if len(args.milliseconds) > 1: + print('Slow-Ranking Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:])) + + print("\n\n") + + print("\n\n") + # print('Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:])) + print("\n\n") + + print('\n\n') + if qrels: + assert query_idx + 1 == len(keys) == len(set(keys)) + metrics.output_final_metrics(os.path.join(Run.path, 'ranking.metrics'), query_idx, len(queries)) + print('\n\n') diff --git a/colbert/evaluation/ranking_logger.py b/colbert/evaluation/ranking_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..2bd48ac90b1fa09f3b028e39f1623fa07a82ccdf --- /dev/null +++ b/colbert/evaluation/ranking_logger.py @@ -0,0 +1,57 @@ +import os + +from contextlib import contextmanager +from colbert.utils.utils import print_message, NullContextManager +from colbert.utils.runs import Run + + +class RankingLogger(): + def __init__(self, directory, qrels=None, log_scores=False): + self.directory = directory + self.qrels = qrels + self.filename, self.also_save_annotations = None, None + self.log_scores = log_scores + + @contextmanager + def context(self, filename, also_save_annotations=False): + assert self.filename is None + assert 
self.also_save_annotations is None + + filename = os.path.join(self.directory, filename) + self.filename, self.also_save_annotations = filename, also_save_annotations + + print_message("#> Logging ranked lists to {}".format(self.filename)) + + with open(filename, 'w') as f: + self.f = f + with (open(filename + '.annotated', 'w') if also_save_annotations else NullContextManager()) as g: + self.g = g + try: + yield self + finally: + pass + + def log(self, qid, ranking, is_ranked=True, print_positions=[]): + print_positions = set(print_positions) + + f_buffer = [] + g_buffer = [] + + for rank, (score, pid, passage) in enumerate(ranking): + is_relevant = self.qrels and int(pid in self.qrels[qid]) + rank = rank+1 if is_ranked else -1 + + possibly_score = [score] if self.log_scores else [] + + f_buffer.append('\t'.join([str(x) for x in [qid, pid, rank] + possibly_score]) + "\n") + if self.g: + g_buffer.append('\t'.join([str(x) for x in [qid, pid, rank, is_relevant]]) + "\n") + + if rank in print_positions: + prefix = "** " if is_relevant else "" + prefix += str(rank) + print("#> ( QID {} ) ".format(qid) + prefix + ") ", pid, ":", score, ' ', passage) + + self.f.write(''.join(f_buffer)) + if self.g: + self.g.write(''.join(g_buffer)) diff --git a/colbert/evaluation/slow.py b/colbert/evaluation/slow.py new file mode 100644 index 0000000000000000000000000000000000000000..5a094c3b0cc081a423c830871abf100a4b4a0e23 --- /dev/null +++ b/colbert/evaluation/slow.py @@ -0,0 +1,21 @@ +import os + +def slow_rerank(args, query, pids, passages): + colbert = args.colbert + inference = args.inference + + Q = inference.queryFromText([query]) + + D_ = inference.docFromText(passages, bsize=args.bsize) + scores = colbert.score(Q, D_).cpu() + + scores = scores.sort(descending=True) + ranked = scores.indices.tolist() + + ranked_scores = scores.values.tolist() + ranked_pids = [pids[position] for position in ranked] + ranked_passages = [passages[position] for position in ranked] + + assert len(ranked_pids) == len(set(ranked_pids)) + + return list(zip(ranked_scores, ranked_pids, ranked_passages)) diff --git a/colbert/index.py b/colbert/index.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2c9c15fe395cf15542dc3ecc6bcf450d6ef6d6 --- /dev/null +++ b/colbert/index.py @@ -0,0 +1,59 @@ +import os +import ujson +import random + +from colbert.utils.runs import Run +from colbert.utils.parser import Arguments +import colbert.utils.distributed as distributed + +from colbert.utils.utils import print_message, create_directory +from colbert.indexing.encoder import CollectionEncoder + + +def main(): + random.seed(12345) + + parser = Arguments(description='Precomputing document representations with ColBERT.') + + parser.add_model_parameters() + parser.add_model_inference_parameters() + parser.add_indexing_input() + + parser.add_argument('--chunksize', dest='chunksize', default=6.0, required=False, type=float) # in GiBs + + args = parser.parse() + + with Run.context(): + args.index_path = os.path.join(args.index_root, args.index_name) + assert not os.path.exists(args.index_path), args.index_path + + distributed.barrier(args.rank) + + if args.rank < 1: + create_directory(args.index_root) + create_directory(args.index_path) + + distributed.barrier(args.rank) + + process_idx = max(0, args.rank) + encoder = CollectionEncoder(args, process_idx=process_idx, num_processes=args.nranks) + encoder.encode() + + distributed.barrier(args.rank) + + # Save metadata. 
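+        # Only the main process (rank < 1) writes the metadata, matching the
+        # directory creation above; other ranks wait at the surrounding barriers.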
+ if args.rank < 1: + metadata_path = os.path.join(args.index_path, 'metadata.json') + print_message("Saving (the following) metadata to", metadata_path, "..") + print(args.input_arguments) + + with open(metadata_path, 'w') as output_metadata: + ujson.dump(args.input_arguments.__dict__, output_metadata) + + distributed.barrier(args.rank) + + +if __name__ == "__main__": + main() + +# TODO: Add resume functionality diff --git a/colbert/index_faiss.py b/colbert/index_faiss.py new file mode 100644 index 0000000000000000000000000000000000000000..f22b0863cedb185918de3355175d3f37c1d9cb7c --- /dev/null +++ b/colbert/index_faiss.py @@ -0,0 +1,43 @@ +import os +import random +import math + +from colbert.utils.runs import Run +from colbert.utils.parser import Arguments +from colbert.indexing.faiss import index_faiss +from colbert.indexing.loaders import load_doclens + + +def main(): + random.seed(12345) + + parser = Arguments(description='Faiss indexing for end-to-end retrieval with ColBERT.') + parser.add_index_use_input() + + parser.add_argument('--sample', dest='sample', default=None, type=float) + parser.add_argument('--slices', dest='slices', default=1, type=int) + + args = parser.parse() + assert args.slices >= 1 + assert args.sample is None or (0.0 < args.sample < 1.0), args.sample + + with Run.context(): + args.index_path = os.path.join(args.index_root, args.index_name) + assert os.path.exists(args.index_path), args.index_path + + num_embeddings = sum(load_doclens(args.index_path)) + print("#> num_embeddings =", num_embeddings) + + if args.partitions is None: + args.partitions = 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings))) + print('\n\n') + Run.warn("You did not specify --partitions!") + Run.warn("Default computation chooses", args.partitions, + "partitions (for {} embeddings)".format(num_embeddings)) + print('\n\n') + + index_faiss(args) + + +if __name__ == "__main__": + main() diff --git a/colbert/indexing/__init__.py b/colbert/indexing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colbert/indexing/__pycache__/__init__.cpython-37.pyc b/colbert/indexing/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..913a2398f3f5d77038746c2b8441e0b70a7c52e6 Binary files /dev/null and b/colbert/indexing/__pycache__/__init__.cpython-37.pyc differ diff --git a/colbert/indexing/__pycache__/encoder.cpython-37.pyc b/colbert/indexing/__pycache__/encoder.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..676a5eda6c45dc823200338daecfc8236845418b Binary files /dev/null and b/colbert/indexing/__pycache__/encoder.cpython-37.pyc differ diff --git a/colbert/indexing/__pycache__/faiss.cpython-37.pyc b/colbert/indexing/__pycache__/faiss.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a09e19655f9c16442f352b8ccd7723d874f9e1c2 Binary files /dev/null and b/colbert/indexing/__pycache__/faiss.cpython-37.pyc differ diff --git a/colbert/indexing/__pycache__/faiss_index.cpython-37.pyc b/colbert/indexing/__pycache__/faiss_index.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..551bb2f48383cc7c38b1c8eb10d6b9b356897b01 Binary files /dev/null and b/colbert/indexing/__pycache__/faiss_index.cpython-37.pyc differ diff --git a/colbert/indexing/__pycache__/faiss_index_gpu.cpython-37.pyc b/colbert/indexing/__pycache__/faiss_index_gpu.cpython-37.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..7657febf54052bc2ce8eac0d79a053d8e81eab6e Binary files /dev/null and b/colbert/indexing/__pycache__/faiss_index_gpu.cpython-37.pyc differ diff --git a/colbert/indexing/__pycache__/index_manager.cpython-37.pyc b/colbert/indexing/__pycache__/index_manager.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..281587436c165e27a842f1daa4a1144093b83254 Binary files /dev/null and b/colbert/indexing/__pycache__/index_manager.cpython-37.pyc differ diff --git a/colbert/indexing/__pycache__/loaders.cpython-37.pyc b/colbert/indexing/__pycache__/loaders.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53e456572f32585dcef1df44e3f7ff9388ab2fb6 Binary files /dev/null and b/colbert/indexing/__pycache__/loaders.cpython-37.pyc differ diff --git a/colbert/indexing/encoder.py b/colbert/indexing/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..92bc77e3cbab3e1599f7e018afef7bd956cbdd9c --- /dev/null +++ b/colbert/indexing/encoder.py @@ -0,0 +1,187 @@ +import os +import time +import torch +import ujson +import numpy as np + +import itertools +import threading +import queue + +from colbert.modeling.inference import ModelInference +from colbert.evaluation.loaders import load_colbert +from colbert.utils.utils import print_message + +from colbert.indexing.index_manager import IndexManager + + +class CollectionEncoder(): + def __init__(self, args, process_idx, num_processes): + self.args = args + self.collection = args.collection + self.process_idx = process_idx + self.num_processes = num_processes + + assert 0.5 <= args.chunksize <= 128.0 + max_bytes_per_file = args.chunksize * (1024*1024*1024) + + max_bytes_per_doc = (self.args.doc_maxlen * self.args.dim * 2.0) + + # Determine subset sizes for output + minimum_subset_size = 10_000 + maximum_subset_size = max_bytes_per_file / max_bytes_per_doc + maximum_subset_size = max(minimum_subset_size, maximum_subset_size) + self.possible_subset_sizes = [int(maximum_subset_size)] + + self.print_main("#> Local args.bsize =", args.bsize) + self.print_main("#> args.index_root =", args.index_root) + self.print_main(f"#> self.possible_subset_sizes = {self.possible_subset_sizes}") + + self._load_model() + self.indexmgr = IndexManager(args.dim) + self.iterator = self._initialize_iterator() + + def _initialize_iterator(self): + return open(self.collection) + + def _saver_thread(self): + for args in iter(self.saver_queue.get, None): + self._save_batch(*args) + + def _load_model(self): + self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0)) + self.colbert = self.colbert.cuda() + self.colbert.eval() + + self.inference = ModelInference(self.colbert, amp=self.args.amp) + + def encode(self): + self.saver_queue = queue.Queue(maxsize=3) + thread = threading.Thread(target=self._saver_thread) + thread.start() + + t0 = time.time() + local_docs_processed = 0 + + for batch_idx, (offset, lines, owner) in enumerate(self._batch_passages(self.iterator)): + if owner != self.process_idx: + continue + + t1 = time.time() + batch = self._preprocess_batch(offset, lines) + embs, doclens = self._encode_batch(batch_idx, batch) + + t2 = time.time() + self.saver_queue.put((batch_idx, embs, offset, doclens)) + + t3 = time.time() + local_docs_processed += len(lines) + overall_throughput = compute_throughput(local_docs_processed, t0, t3) + this_encoding_throughput = compute_throughput(len(lines), t1, t2) + this_saving_throughput = 
compute_throughput(len(lines), t2, t3) + + self.print(f'#> Completed batch #{batch_idx} (starting at passage #{offset}) \t\t' + f'Passages/min: {overall_throughput} (overall), ', + f'{this_encoding_throughput} (this encoding), ', + f'{this_saving_throughput} (this saving)') + self.saver_queue.put(None) + + self.print("#> Joining saver thread.") + thread.join() + + def _batch_passages(self, fi): + """ + Must use the same seed across processes! + """ + np.random.seed(0) + + offset = 0 + for owner in itertools.cycle(range(self.num_processes)): + batch_size = np.random.choice(self.possible_subset_sizes) + + L = [line for _, line in zip(range(batch_size), fi)] + + if len(L) == 0: + break # EOF + + yield (offset, L, owner) + offset += len(L) + + if len(L) < batch_size: + break # EOF + + self.print("[NOTE] Done with local share.") + + return + + def _preprocess_batch(self, offset, lines): + endpos = offset + len(lines) + + batch = [] + + for line_idx, line in zip(range(offset, endpos), lines): + line_parts = line.strip().split('\t') + + pid, passage, *other = line_parts + + assert len(passage) >= 1 + + if len(other) >= 1: + title, *_ = other + passage = title + ' | ' + passage + + batch.append(passage) + + # assert pid == 'id' or int(pid) == line_idx + + return batch + + def _encode_batch(self, batch_idx, batch): + with torch.no_grad(): + embs = self.inference.docFromText(batch, bsize=self.args.bsize, keep_dims=False) + assert type(embs) is list + assert len(embs) == len(batch) + + local_doclens = [d.size(0) for d in embs] + embs = torch.cat(embs) + + return embs, local_doclens + + def _save_batch(self, batch_idx, embs, offset, doclens): + start_time = time.time() + + output_path = os.path.join(self.args.index_path, "{}.pt".format(batch_idx)) + output_sample_path = os.path.join(self.args.index_path, "{}.sample".format(batch_idx)) + doclens_path = os.path.join(self.args.index_path, 'doclens.{}.json'.format(batch_idx)) + + # Save the embeddings. + self.indexmgr.save(embs, output_path) + self.indexmgr.save(embs[torch.randint(0, high=embs.size(0), size=(embs.size(0) // 20,))], output_sample_path) + + # Save the doclens. 
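+        # Per-passage token counts for this batch; loaders later use these to split
+        # the flat embedding tensor back into per-passage matrices.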
+ with open(doclens_path, 'w') as output_doclens: + ujson.dump(doclens, output_doclens) + + throughput = compute_throughput(len(doclens), start_time, time.time()) + self.print_main("#> Saved batch #{} to {} \t\t".format(batch_idx, output_path), + "Saving Throughput =", throughput, "passages per minute.\n") + + def print(self, *args): + print_message("[" + str(self.process_idx) + "]", "\t\t", *args) + + def print_main(self, *args): + if self.process_idx == 0: + self.print(*args) + + +def compute_throughput(size, t0, t1): + throughput = size / (t1 - t0) * 60 + + if throughput > 1000 * 1000: + throughput = throughput / (1000*1000) + throughput = round(throughput, 1) + return '{}M'.format(throughput) + + throughput = throughput / (1000) + throughput = round(throughput, 1) + return '{}k'.format(throughput) diff --git a/colbert/indexing/faiss.py b/colbert/indexing/faiss.py new file mode 100644 index 0000000000000000000000000000000000000000..a19670e1c60f963c9948f8353f22d3821c3676d0 --- /dev/null +++ b/colbert/indexing/faiss.py @@ -0,0 +1,116 @@ +import os +import math +import faiss +import torch +import numpy as np + +import threading +import queue + +from colbert.utils.utils import print_message, grouper +from colbert.indexing.loaders import get_parts +from colbert.indexing.index_manager import load_index_part +from colbert.indexing.faiss_index import FaissIndex + + +def get_faiss_index_name(args, offset=None, endpos=None): + partitions_info = '' if args.partitions is None else f'.{args.partitions}' + range_info = '' if offset is None else f'.{offset}-{endpos}' + + return f'ivfpq{partitions_info}{range_info}.faiss' + + +def load_sample(samples_paths, sample_fraction=None): + sample = [] + + for filename in samples_paths: + print_message(f"#> Loading {filename} ...") + part = load_index_part(filename) + if sample_fraction: + part = part[torch.randint(0, high=part.size(0), size=(int(part.size(0) * sample_fraction),))] + sample.append(part) + + sample = torch.cat(sample).float().numpy() + + print("#> Sample has shape", sample.shape) + + return sample + + +def prepare_faiss_index(slice_samples_paths, partitions, sample_fraction=None): + training_sample = load_sample(slice_samples_paths, sample_fraction=sample_fraction) + + dim = training_sample.shape[-1] + index = FaissIndex(dim, partitions) + + print_message("#> Training with the vectors...") + + index.train(training_sample) + + print_message("Done training!\n") + + return index + + +SPAN = 3 + + +def index_faiss(args): + print_message("#> Starting..") + + parts, parts_paths, samples_paths = get_parts(args.index_path) + + if args.sample is not None: + assert args.sample, args.sample + print_message(f"#> Training with {round(args.sample * 100.0, 1)}% of *all* embeddings (provided --sample).") + samples_paths = parts_paths + + num_parts_per_slice = math.ceil(len(parts) / args.slices) + + for slice_idx, part_offset in enumerate(range(0, len(parts), num_parts_per_slice)): + part_endpos = min(part_offset + num_parts_per_slice, len(parts)) + + slice_parts_paths = parts_paths[part_offset:part_endpos] + slice_samples_paths = samples_paths[part_offset:part_endpos] + + if args.slices == 1: + faiss_index_name = get_faiss_index_name(args) + else: + faiss_index_name = get_faiss_index_name(args, offset=part_offset, endpos=part_endpos) + + output_path = os.path.join(args.index_path, faiss_index_name) + print_message(f"#> Processing slice #{slice_idx+1} of {args.slices} (range {part_offset}..{part_endpos}).") + print_message(f"#> Will write to {output_path}.") + + 
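+        # Refuse to overwrite an existing FAISS index at this output path.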
assert not os.path.exists(output_path), output_path + + index = prepare_faiss_index(slice_samples_paths, args.partitions, args.sample) + + loaded_parts = queue.Queue(maxsize=1) + + def _loader_thread(thread_parts_paths): + for filenames in grouper(thread_parts_paths, SPAN, fillvalue=None): + sub_collection = [load_index_part(filename) for filename in filenames if filename is not None] + sub_collection = torch.cat(sub_collection) + sub_collection = sub_collection.float().numpy() + loaded_parts.put(sub_collection) + + thread = threading.Thread(target=_loader_thread, args=(slice_parts_paths,)) + thread.start() + + print_message("#> Indexing the vectors...") + + for filenames in grouper(slice_parts_paths, SPAN, fillvalue=None): + print_message("#> Loading", filenames, "(from queue)...") + sub_collection = loaded_parts.get() + + print_message("#> Processing a sub_collection with shape", sub_collection.shape) + index.add(sub_collection) + + print_message("Done indexing!") + + index.save(output_path) + + print_message(f"\n\nDone! All complete (for slice #{slice_idx+1} of {args.slices})!") + + thread.join() diff --git a/colbert/indexing/faiss_index.py b/colbert/indexing/faiss_index.py new file mode 100644 index 0000000000000000000000000000000000000000..32349289adf427ec562e8735beabbe0cd9c42010 --- /dev/null +++ b/colbert/indexing/faiss_index.py @@ -0,0 +1,58 @@ +import sys +import time +import math +import faiss +import torch + +import numpy as np + +from colbert.indexing.faiss_index_gpu import FaissIndexGPU +from colbert.utils.utils import print_message + + +class FaissIndex(): + def __init__(self, dim, partitions): + self.dim = dim + self.partitions = partitions + + self.gpu = FaissIndexGPU() + self.quantizer, self.index = self._create_index() + self.offset = 0 + + def _create_index(self): + quantizer = faiss.IndexFlatL2(self.dim) # faiss.IndexHNSWFlat(dim, 32) + index = faiss.IndexIVFPQ(quantizer, self.dim, self.partitions, 16, 8) + + return quantizer, index + + def train(self, train_data): + print_message(f"#> Training now (using {self.gpu.ngpu} GPUs)...") + + if self.gpu.ngpu > 0: + self.gpu.training_initialize(self.index, self.quantizer) + + s = time.time() + self.index.train(train_data) + print(time.time() - s) + + if self.gpu.ngpu > 0: + self.gpu.training_finalize() + + def add(self, data): + print_message(f"Add data with shape {data.shape} (offset = {self.offset})..") + + if self.gpu.ngpu > 0 and self.offset == 0: + self.gpu.adding_initialize(self.index) + + if self.gpu.ngpu > 0: + self.gpu.add(self.index, data, self.offset) + else: + self.index.add(data) + + self.offset += data.shape[0] + + def save(self, output_path): + print_message(f"Writing index to {output_path} ...") + + self.index.nprobe = 10 # just a default + faiss.write_index(self.index, output_path) diff --git a/colbert/indexing/faiss_index_gpu.py b/colbert/indexing/faiss_index_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..e718dd3f569791d3559bd461d25dbe9a553f0ff0 --- /dev/null +++ b/colbert/indexing/faiss_index_gpu.py @@ -0,0 +1,138 @@ +""" + Heavily based on: https://github.com/facebookresearch/faiss/blob/master/benchs/bench_gpu_1bn.py +""" + + +import sys +import time +import math +import faiss +import torch + +import numpy as np +from colbert.utils.utils import print_message + + +class FaissIndexGPU(): + def __init__(self): + self.ngpu = faiss.get_num_gpus() + + if self.ngpu == 0: + return + + self.tempmem = 1 << 33 + self.max_add_per_gpu = 1 << 25 + self.max_add = self.max_add_per_gpu * 
self.ngpu + self.add_batch_size = 65536 + + self.gpu_resources = self._prepare_gpu_resources() + + def _prepare_gpu_resources(self): + print_message(f"Preparing resources for {self.ngpu} GPUs.") + + gpu_resources = [] + + for _ in range(self.ngpu): + res = faiss.StandardGpuResources() + if self.tempmem >= 0: + res.setTempMemory(self.tempmem) + gpu_resources.append(res) + + return gpu_resources + + def _make_vres_vdev(self): + """ + return vectors of device ids and resources useful for gpu_multiple + """ + + assert self.ngpu > 0 + + vres = faiss.GpuResourcesVector() + vdev = faiss.IntVector() + + for i in range(self.ngpu): + vdev.push_back(i) + vres.push_back(self.gpu_resources[i]) + + return vres, vdev + + def training_initialize(self, index, quantizer): + """ + The index and quantizer should be owned by caller. + """ + + assert self.ngpu > 0 + + s = time.time() + self.index_ivf = faiss.extract_index_ivf(index) + self.clustering_index = faiss.index_cpu_to_all_gpus(quantizer) + self.index_ivf.clustering_index = self.clustering_index + print(time.time() - s) + + def training_finalize(self): + assert self.ngpu > 0 + + s = time.time() + self.index_ivf.clustering_index = faiss.index_gpu_to_cpu(self.index_ivf.clustering_index) + print(time.time() - s) + + def adding_initialize(self, index): + """ + The index should be owned by caller. + """ + + assert self.ngpu > 0 + + self.co = faiss.GpuMultipleClonerOptions() + self.co.useFloat16 = True + self.co.useFloat16CoarseQuantizer = False + self.co.usePrecomputed = False + self.co.indicesOptions = faiss.INDICES_CPU + self.co.verbose = True + self.co.reserveVecs = self.max_add + self.co.shard = True + assert self.co.shard_type in (0, 1, 2) + + self.vres, self.vdev = self._make_vres_vdev() + self.gpu_index = faiss.index_cpu_to_gpu_multiple(self.vres, self.vdev, index, self.co) + + def add(self, index, data, offset): + assert self.ngpu > 0 + + t0 = time.time() + nb = data.shape[0] + + for i0 in range(0, nb, self.add_batch_size): + i1 = min(i0 + self.add_batch_size, nb) + xs = data[i0:i1] + + self.gpu_index.add_with_ids(xs, np.arange(offset+i0, offset+i1)) + + if self.max_add > 0 and self.gpu_index.ntotal > self.max_add: + self._flush_to_cpu(index, nb, offset) + + print('\r%d/%d (%.3f s) ' % (i0, nb, time.time() - t0), end=' ') + sys.stdout.flush() + + if self.gpu_index.ntotal > 0: + self._flush_to_cpu(index, nb, offset) + + assert index.ntotal == offset+nb, (index.ntotal, offset+nb, offset, nb) + print(f"add(.) 
time: %.3f s \t\t--\t\t index.ntotal = {index.ntotal}" % (time.time() - t0)) + + def _flush_to_cpu(self, index, nb, offset): + print("Flush indexes to CPU") + + for i in range(self.ngpu): + index_src_gpu = faiss.downcast_index(self.gpu_index if self.ngpu == 1 else self.gpu_index.at(i)) + index_src = faiss.index_gpu_to_cpu(index_src_gpu) + + index_src.copy_subset_to(index, 0, offset, offset+nb) + index_src_gpu.reset() + index_src_gpu.reserveMemory(self.max_add) + + if self.ngpu > 1: + try: + self.gpu_index.sync_with_shard_indexes() + except: + self.gpu_index.syncWithSubIndexes() diff --git a/colbert/indexing/index_manager.py b/colbert/indexing/index_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2a66f2852d4aaa1d6cf5f4eb7b466fa12d024a --- /dev/null +++ b/colbert/indexing/index_manager.py @@ -0,0 +1,22 @@ +import torch +import faiss +import numpy as np + +from colbert.utils.utils import print_message + + +class IndexManager(): + def __init__(self, dim): + self.dim = dim + + def save(self, tensor, path_prefix): + torch.save(tensor, path_prefix) + + +def load_index_part(filename, verbose=True): + part = torch.load(filename) + + if type(part) == list: # for backward compatibility + part = torch.cat(part) + + return part diff --git a/colbert/indexing/loaders.py b/colbert/indexing/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb8d67c49348edeb6a1f7537bd5cd6799caed34 --- /dev/null +++ b/colbert/indexing/loaders.py @@ -0,0 +1,34 @@ +import os +import torch +import ujson + +from math import ceil +from itertools import accumulate +from colbert.utils.utils import print_message + + +def get_parts(directory): + extension = '.pt' + + parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory) + if filename.endswith(extension)]) + + assert list(range(len(parts))) == parts, parts + + # Integer-sortedness matters. 
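+    # Build the paths in that same numeric order so part files, samples, and
+    # doclens files stay aligned part-by-part.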
+ parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts] + samples_paths = [os.path.join(directory, '{}.sample'.format(filename)) for filename in parts] + + return parts, parts_paths, samples_paths + + +def load_doclens(directory, flatten=True): + parts, _, _ = get_parts(directory) + + doclens_filenames = [os.path.join(directory, 'doclens.{}.json'.format(filename)) for filename in parts] + all_doclens = [ujson.load(open(filename)) for filename in doclens_filenames] + + if flatten: + all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens] + + return all_doclens diff --git a/colbert/modeling/__init__.py b/colbert/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colbert/modeling/__pycache__/__init__.cpython-37.pyc b/colbert/modeling/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..454905b6284026ca6952b5c7969e6836e668bf7c Binary files /dev/null and b/colbert/modeling/__pycache__/__init__.cpython-37.pyc differ diff --git a/colbert/modeling/__pycache__/colbert.cpython-37.pyc b/colbert/modeling/__pycache__/colbert.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8af8ff79855fd068525693d619812d752af7ab85 Binary files /dev/null and b/colbert/modeling/__pycache__/colbert.cpython-37.pyc differ diff --git a/colbert/modeling/__pycache__/inference.cpython-37.pyc b/colbert/modeling/__pycache__/inference.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ed1a2102744796f39dc73b158373bb5562851a5 Binary files /dev/null and b/colbert/modeling/__pycache__/inference.cpython-37.pyc differ diff --git a/colbert/modeling/colbert.py b/colbert/modeling/colbert.py new file mode 100644 index 0000000000000000000000000000000000000000..1fd267f1ba57c4ce6d5537652812028e0c8524e7 --- /dev/null +++ b/colbert/modeling/colbert.py @@ -0,0 +1,73 @@ +import string +import torch +import torch.nn as nn + +from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast +from colbert.parameters import DEVICE + + +class ColBERT(BertPreTrainedModel): + def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'): + + super(ColBERT, self).__init__(config) + + self.query_maxlen = query_maxlen + self.doc_maxlen = doc_maxlen + self.similarity_metric = similarity_metric + self.dim = dim + + self.mask_punctuation = mask_punctuation + self.skiplist = {} + + if self.mask_punctuation: + self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased') + self.skiplist = {w: True + for symbol in string.punctuation + for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]} + + self.bert = BertModel(config) + self.linear = nn.Linear(config.hidden_size, dim * 2, bias=False) + + self.init_weights() + + def forward(self, Q, D): + return self.score(self.query(*Q), self.doc(*D)) + + def query(self, input_ids, attention_mask): + input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE) + Q = self.bert(input_ids, attention_mask=attention_mask)[0] + Q = self.linear(Q) + Q = Q.split(int(Q.size(2)/2),2) + Q = torch.cat(Q,1) + + return torch.nn.functional.normalize(Q, p=2, dim=2) + + def doc(self, input_ids, attention_mask, keep_dims=True): + input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE) + D = self.bert(input_ids, 
attention_mask=attention_mask)[0] + D = self.linear(D) + D = D.split(int(D.size(2)/2),2) + D = torch.cat(D,1) + + mask = torch.tensor(self.mask(input_ids), device=DEVICE).unsqueeze(2).float() + mask = torch.cat(2*[mask],1) + D = D * mask + + D = torch.nn.functional.normalize(D, p=2, dim=2) + + if not keep_dims: + D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1) + D = [d[mask[idx]] for idx, d in enumerate(D)] + + return D + + def score(self, Q, D): + if self.similarity_metric == 'cosine': + return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1) + + assert self.similarity_metric == 'l2' + return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1) + + def mask(self, input_ids): + mask = [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()] + return mask diff --git a/colbert/modeling/inference.py b/colbert/modeling/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..10833e559cdd48850273856cd16d9742168573f7 --- /dev/null +++ b/colbert/modeling/inference.py @@ -0,0 +1,87 @@ +import torch + +from colbert.modeling.colbert import ColBERT +from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer +from colbert.utils.amp import MixedPrecisionManager +from colbert.parameters import DEVICE + + +class ModelInference(): + def __init__(self, colbert: ColBERT, amp=False): + assert colbert.training is False + + self.colbert = colbert + self.query_tokenizer = QueryTokenizer(colbert.query_maxlen) + self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen) + + self.amp_manager = MixedPrecisionManager(amp) + + def query(self, *args, to_cpu=False, **kw_args): + with torch.no_grad(): + with self.amp_manager.context(): + Q = self.colbert.query(*args, **kw_args) + return Q.cpu() if to_cpu else Q + + def doc(self, *args, to_cpu=False, **kw_args): + with torch.no_grad(): + with self.amp_manager.context(): + D = self.colbert.doc(*args, **kw_args) + return D.cpu() if to_cpu else D + + def queryFromText(self, queries, bsize=None, to_cpu=False): + if bsize: + batches = self.query_tokenizer.tensorize(queries, bsize=bsize) + batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches] + return torch.cat(batches) + + input_ids, attention_mask = self.query_tokenizer.tensorize(queries) + return self.query(input_ids, attention_mask) + + def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False): + if bsize: + batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize) + + batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu) + for input_ids, attention_mask in batches] + + if keep_dims: + D = _stack_3D_tensors(batches) + return D[reverse_indices] + + D = [d for batch in batches for d in batch] + return [D[idx] for idx in reverse_indices.tolist()] + + input_ids, attention_mask = self.doc_tokenizer.tensorize(docs) + return self.doc(input_ids, attention_mask, keep_dims=keep_dims) + + def score(self, Q, D, mask=None, lengths=None, explain=False): + if lengths is not None: + assert mask is None, "don't supply both mask and lengths" + + mask = torch.arange(D.size(1), device=DEVICE) + 1 + mask = mask.unsqueeze(0) <= lengths.to(DEVICE).unsqueeze(-1) + + scores = (D @ Q) + scores = scores if mask is None else scores * mask.unsqueeze(-1) + scores = scores.max(1) + + if explain: + assert False, "TODO" + + return scores.values.sum(-1).cpu() + + +def _stack_3D_tensors(groups): + bsize = sum([x.size(0) for x in groups]) + 
maxlen = max([x.size(1) for x in groups]) + hdim = groups[0].size(2) + + output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype) + + offset = 0 + for x in groups: + endpos = offset + x.size(0) + output[offset:endpos, :x.size(1)] = x + offset = endpos + + return output diff --git a/colbert/modeling/tokenization/__init__.py b/colbert/modeling/tokenization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7cc0206006c3b35c00b94f43ca8b237ce7130e --- /dev/null +++ b/colbert/modeling/tokenization/__init__.py @@ -0,0 +1,3 @@ +from colbert.modeling.tokenization.query_tokenization import * +from colbert.modeling.tokenization.doc_tokenization import * +from colbert.modeling.tokenization.utils import tensorize_triples diff --git a/colbert/modeling/tokenization/__pycache__/__init__.cpython-37.pyc b/colbert/modeling/tokenization/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d873f50ea9cdd3da7400ebfc57afba0c4d7b3209 Binary files /dev/null and b/colbert/modeling/tokenization/__pycache__/__init__.cpython-37.pyc differ diff --git a/colbert/modeling/tokenization/__pycache__/doc_tokenization.cpython-37.pyc b/colbert/modeling/tokenization/__pycache__/doc_tokenization.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d28db848bfb91a34d3a913d57fe92d6640a7de4 Binary files /dev/null and b/colbert/modeling/tokenization/__pycache__/doc_tokenization.cpython-37.pyc differ diff --git a/colbert/modeling/tokenization/__pycache__/query_tokenization.cpython-37.pyc b/colbert/modeling/tokenization/__pycache__/query_tokenization.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7edf5822df39dd8877a3d1e7f23152932894caae Binary files /dev/null and b/colbert/modeling/tokenization/__pycache__/query_tokenization.cpython-37.pyc differ diff --git a/colbert/modeling/tokenization/__pycache__/utils.cpython-37.pyc b/colbert/modeling/tokenization/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fecb70bb92f96a2b74625a6e87ee555bb0bd1a47 Binary files /dev/null and b/colbert/modeling/tokenization/__pycache__/utils.cpython-37.pyc differ diff --git a/colbert/modeling/tokenization/doc_tokenization.py b/colbert/modeling/tokenization/doc_tokenization.py new file mode 100644 index 0000000000000000000000000000000000000000..db5b91775098a26d7014e9a7d2793e44420cdb35 --- /dev/null +++ b/colbert/modeling/tokenization/doc_tokenization.py @@ -0,0 +1,63 @@ +import torch + +from transformers import BertTokenizerFast +from colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length + + +class DocTokenizer(): + def __init__(self, doc_maxlen): + self.tok = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased') + self.doc_maxlen = doc_maxlen + + self.D_marker_token, self.D_marker_token_id = '[D]', self.tok.convert_tokens_to_ids('[unused1]') + self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id + self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id + + assert self.D_marker_token_id == 1 + + def tokenize(self, batch_text, add_special_tokens=False): + assert type(batch_text) in [list, tuple], (type(batch_text)) + + tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text] + + if not add_special_tokens: + return tokens + + prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token] + tokens = [prefix + lst + 
suffix for lst in tokens]
+
+        return tokens
+
+    def encode(self, batch_text, add_special_tokens=False):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']
+
+        if not add_special_tokens:
+            return ids
+
+        prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
+        ids = [prefix + lst + suffix for lst in ids]
+
+        return ids
+
+    def tensorize(self, batch_text, bsize=None):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        # add a placeholder for the [D] marker
+        batch_text = ['. ' + x for x in batch_text]
+
+        obj = self.tok(batch_text, padding='longest', truncation='longest_first',
+                       return_tensors='pt', max_length=self.doc_maxlen)
+
+        ids, mask = obj['input_ids'], obj['attention_mask']
+
+        # postprocess for the [D] marker
+        ids[:, 1] = self.D_marker_token_id
+
+        if bsize:
+            ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
+            batches = _split_into_batches(ids, mask, bsize)
+            return batches, reverse_indices
+
+        return ids, mask
diff --git a/colbert/modeling/tokenization/query_tokenization.py b/colbert/modeling/tokenization/query_tokenization.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca507b641b7b8587842530b48b1f4ed513210760
--- /dev/null
+++ b/colbert/modeling/tokenization/query_tokenization.py
@@ -0,0 +1,64 @@
+import torch
+
+from transformers import BertTokenizerFast
+from colbert.modeling.tokenization.utils import _split_into_batches
+
+
+class QueryTokenizer():
+    def __init__(self, query_maxlen):
+        self.tok = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
+        self.query_maxlen = query_maxlen
+
+        self.Q_marker_token, self.Q_marker_token_id = '[Q]', self.tok.convert_tokens_to_ids('[unused0]')
+        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
+        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
+        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id
+
+        assert self.Q_marker_token_id == 100 and self.mask_token_id == 103
+
+    def tokenize(self, batch_text, add_special_tokens=False):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]
+
+        if not add_special_tokens:
+            return tokens
+
+        prefix, suffix = [self.cls_token, self.Q_marker_token], [self.sep_token]
+        tokens = [prefix + lst + suffix + [self.mask_token] * (self.query_maxlen - (len(lst)+3)) for lst in tokens]
+
+        return tokens
+
+    def encode(self, batch_text, add_special_tokens=False):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']
+
+        if not add_special_tokens:
+            return ids
+
+        prefix, suffix = [self.cls_token_id, self.Q_marker_token_id], [self.sep_token_id]
+        ids = [prefix + lst + suffix + [self.mask_token_id] * (self.query_maxlen - (len(lst)+3)) for lst in ids]
+
+        return ids
+
+    def tensorize(self, batch_text, bsize=None):
+        assert type(batch_text) in [list, tuple], (type(batch_text))
+
+        # add a placeholder for the [Q] marker
+        batch_text = ['. 
' + x for x in batch_text] + + obj = self.tok(batch_text, padding='max_length', truncation=True, + return_tensors='pt', max_length=self.query_maxlen) + + ids, mask = obj['input_ids'], obj['attention_mask'] + + # postprocess for the [Q] marker and the [MASK] augmentation + ids[:, 1] = self.Q_marker_token_id + ids[ids == 0] = self.mask_token_id + + if bsize: + batches = _split_into_batches(ids, mask, bsize) + return batches + + return ids, mask diff --git a/colbert/modeling/tokenization/utils.py b/colbert/modeling/tokenization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..001f2562b4475289c3b1001fa953e09f8468d3f8 --- /dev/null +++ b/colbert/modeling/tokenization/utils.py @@ -0,0 +1,51 @@ +import torch + + +def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize): + assert len(queries) == len(positives) == len(negatives) + assert bsize is None or len(queries) % bsize == 0 + + N = len(queries) + Q_ids, Q_mask = query_tokenizer.tensorize(queries) + D_ids, D_mask = doc_tokenizer.tensorize(positives + negatives) + D_ids, D_mask = D_ids.view(2, N, -1), D_mask.view(2, N, -1) + + # Compute max among {length of i^th positive, length of i^th negative} for i \in N + maxlens = D_mask.sum(-1).max(0).values + + # Sort by maxlens + indices = maxlens.sort().indices + Q_ids, Q_mask = Q_ids[indices], Q_mask[indices] + D_ids, D_mask = D_ids[:, indices], D_mask[:, indices] + + (positive_ids, negative_ids), (positive_mask, negative_mask) = D_ids, D_mask + + query_batches = _split_into_batches(Q_ids, Q_mask, bsize) + positive_batches = _split_into_batches(positive_ids, positive_mask, bsize) + negative_batches = _split_into_batches(negative_ids, negative_mask, bsize) + + batches = [] + for (q_ids, q_mask), (p_ids, p_mask), (n_ids, n_mask) in zip(query_batches, positive_batches, negative_batches): + Q = (torch.cat((q_ids, q_ids)), torch.cat((q_mask, q_mask))) + D = (torch.cat((p_ids, n_ids)), torch.cat((p_mask, n_mask))) + batches.append((Q, D)) + + return batches + + +def _sort_by_length(ids, mask, bsize): + if ids.size(0) <= bsize: + return ids, mask, torch.arange(ids.size(0)) + + indices = mask.sum(-1).sort().indices + reverse_indices = indices.sort().indices + + return ids[indices], mask[indices], reverse_indices + + +def _split_into_batches(ids, mask, bsize): + batches = [] + for offset in range(0, ids.size(0), bsize): + batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize])) + + return batches diff --git a/colbert/parameters.py b/colbert/parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..bc70fb046f7e1d8937b71205b11dafc76c8de8c9 --- /dev/null +++ b/colbert/parameters.py @@ -0,0 +1,9 @@ +import torch + +DEVICE = torch.device("cuda") + +SAVED_CHECKPOINTS = [32*1000, 100*1000, 150*1000, 200*1000, 300*1000, 400*1000] +SAVED_CHECKPOINTS += [10*1000, 20*1000, 30*1000, 40*1000, 50*1000, 60*1000, 70*1000, 80*1000, 90*1000] +SAVED_CHECKPOINTS += [25*1000, 50*1000, 75*1000] + +SAVED_CHECKPOINTS = set(SAVED_CHECKPOINTS) diff --git a/colbert/ranking/__init__.py b/colbert/ranking/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colbert/ranking/__pycache__/__init__.cpython-37.pyc b/colbert/ranking/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..768d0661dcccf44583bdcacbcdb641e18592d05a Binary files /dev/null and b/colbert/ranking/__pycache__/__init__.cpython-37.pyc differ 
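Taken together, the modeling and tokenization files above implement ColBERT's late interaction: queries are padded to `query_maxlen` with `[MASK]` tokens, punctuation embeddings in documents are zeroed out via the mask, and relevance reduces to a MaxSim over L2-normalized token embeddings. The sketch below illustrates just the scoring step, mirroring `ColBERT.score` with random tensors standing in for real encoder outputs (the shapes are illustrative; in this codebase the split-and-concatenate in `query`/`doc` doubles the token axis):

```
import torch

# Stand-ins for ColBERT.query(...) and ColBERT.doc(...) outputs:
# (batch, num_tokens, dim) tensors, L2-normalized along the last axis.
Q = torch.nn.functional.normalize(torch.randn(2, 32, 128), p=2, dim=2)
D = torch.nn.functional.normalize(torch.randn(2, 180, 128), p=2, dim=2)

# Cosine MaxSim, as in ColBERT.score: similarity of every query token with every
# document token, max over document tokens, then sum over query tokens.
scores = (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)
print(scores.shape)  # torch.Size([2]), i.e., one score per query--document pair
```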
diff --git a/colbert/ranking/__pycache__/batch_retrieval.cpython-37.pyc b/colbert/ranking/__pycache__/batch_retrieval.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52ea439582353220a19b91fbab14eaed5cdfeade Binary files /dev/null and b/colbert/ranking/__pycache__/batch_retrieval.cpython-37.pyc differ diff --git a/colbert/ranking/__pycache__/faiss_index.cpython-37.pyc b/colbert/ranking/__pycache__/faiss_index.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b11a2c6b3cdadb60883d34e2068976f7e6fd0f21 Binary files /dev/null and b/colbert/ranking/__pycache__/faiss_index.cpython-37.pyc differ diff --git a/colbert/ranking/__pycache__/index_part.cpython-37.pyc b/colbert/ranking/__pycache__/index_part.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d16a386eddf6856abacf6d9e3ca02873e6df5737 Binary files /dev/null and b/colbert/ranking/__pycache__/index_part.cpython-37.pyc differ diff --git a/colbert/ranking/__pycache__/index_ranker.cpython-37.pyc b/colbert/ranking/__pycache__/index_ranker.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9809c14afc46572cf5f0cf8c45dd0b21d2ba391 Binary files /dev/null and b/colbert/ranking/__pycache__/index_ranker.cpython-37.pyc differ diff --git a/colbert/ranking/__pycache__/rankers.cpython-37.pyc b/colbert/ranking/__pycache__/rankers.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a6f8c254752f9fedd635f0a9f1a925f3e6d14be Binary files /dev/null and b/colbert/ranking/__pycache__/rankers.cpython-37.pyc differ diff --git a/colbert/ranking/__pycache__/retrieval.cpython-37.pyc b/colbert/ranking/__pycache__/retrieval.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93d3b5efc00414da733c5592c95467d91bd64090 Binary files /dev/null and b/colbert/ranking/__pycache__/retrieval.cpython-37.pyc differ diff --git a/colbert/ranking/batch_reranking.py b/colbert/ranking/batch_reranking.py new file mode 100644 index 0000000000000000000000000000000000000000..121ae00cc8a1a7d214f7cb6f77b97b3cc238a68e --- /dev/null +++ b/colbert/ranking/batch_reranking.py @@ -0,0 +1,131 @@ +import os +import time +import torch +import queue +import threading + +from collections import defaultdict + +from colbert.utils.runs import Run +from colbert.modeling.inference import ModelInference +from colbert.evaluation.ranking_logger import RankingLogger + +from colbert.utils.utils import print_message, flatten, zipstar +from colbert.indexing.loaders import get_parts +from colbert.ranking.index_part import IndexPart + +MAX_DEPTH_LOGGED = 1000 # TODO: Use args.depth + + +def prepare_ranges(index_path, dim, step, part_range): + print_message("#> Launching a separate thread to load index parts asynchronously.") + parts, _, _ = get_parts(index_path) + + positions = [(offset, offset + step) for offset in range(0, len(parts), step)] + + if part_range is not None: + positions = positions[part_range.start: part_range.stop] + + loaded_parts = queue.Queue(maxsize=2) + + def _loader_thread(index_path, dim, positions): + for offset, endpos in positions: + index = IndexPart(index_path, dim=dim, part_range=range(offset, endpos), verbose=True) + loaded_parts.put(index, block=True) + + thread = threading.Thread(target=_loader_thread, args=(index_path, dim, positions,)) + thread.start() + + return positions, loaded_parts, thread + + +def score_by_range(positions, loaded_parts, all_query_embeddings, all_query_rankings, 
all_pids): + print_message("#> Sorting by PID..") + all_query_indexes, all_pids = zipstar(all_pids) + sorting_pids = torch.tensor(all_pids).sort() + all_query_indexes, all_pids = torch.tensor(all_query_indexes)[sorting_pids.indices], sorting_pids.values + + range_start, range_end = 0, 0 + + for offset, endpos in positions: + print_message(f"#> Fetching parts {offset}--{endpos} from queue..") + index = loaded_parts.get() + + print_message(f"#> Filtering PIDs to the range {index.pids_range}..") + range_start = range_start + (all_pids[range_start:] < index.pids_range.start).sum() + range_end = range_end + (all_pids[range_end:] < index.pids_range.stop).sum() + + pids = all_pids[range_start:range_end] + query_indexes = all_query_indexes[range_start:range_end] + + print_message(f"#> Got {len(pids)} query--passage pairs in this range.") + + if len(pids) == 0: + continue + + print_message(f"#> Ranking in batches the pairs #{range_start} through #{range_end}...") + scores = index.batch_rank(all_query_embeddings, query_indexes, pids, sorted_pids=True) + + for query_index, pid, score in zip(query_indexes.tolist(), pids.tolist(), scores): + all_query_rankings[0][query_index].append(pid) + all_query_rankings[1][query_index].append(score) + + +def batch_rerank(args): + positions, loaded_parts, thread = prepare_ranges(args.index_path, args.dim, args.step, args.part_range) + + inference = ModelInference(args.colbert, amp=args.amp) + queries, topK_pids = args.queries, args.topK_pids + + with torch.no_grad(): + queries_in_order = list(queries.values()) + + print_message(f"#> Encoding all {len(queries_in_order)} queries in batches...") + + all_query_embeddings = inference.queryFromText(queries_in_order, bsize=512, to_cpu=True) + all_query_embeddings = all_query_embeddings.to(dtype=torch.float16).permute(0, 2, 1).contiguous() + + for qid in queries: + """ + Since topK_pids is a defaultdict, make sure each qid *has* actual PID information (even if empty). 
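+        (A qid with an empty candidate list contributes no query--passage pairs and is
+        simply skipped at logging time by the K == 0 check below.)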
+ """ + assert qid in topK_pids, qid + + all_pids = flatten([[(query_index, pid) for pid in topK_pids[qid]] for query_index, qid in enumerate(queries)]) + all_query_rankings = [defaultdict(list), defaultdict(list)] + + print_message(f"#> Will process {len(all_pids)} query--document pairs in total.") + + with torch.no_grad(): + score_by_range(positions, loaded_parts, all_query_embeddings, all_query_rankings, all_pids) + + ranking_logger = RankingLogger(Run.path, qrels=None, log_scores=args.log_scores) + + with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger: + with torch.no_grad(): + for query_index, qid in enumerate(queries): + if query_index % 1000 == 0: + print_message("#> Logging query #{} (qid {}) now...".format(query_index, qid)) + + pids = all_query_rankings[0][query_index] + scores = all_query_rankings[1][query_index] + + K = min(MAX_DEPTH_LOGGED, len(scores)) + + if K == 0: + continue + + scores_topk = torch.tensor(scores).topk(K, largest=True, sorted=True) + + pids, scores = torch.tensor(pids)[scores_topk.indices].tolist(), scores_topk.values.tolist() + + ranking = [(score, pid, None) for pid, score in zip(pids, scores)] + assert len(ranking) <= MAX_DEPTH_LOGGED, (len(ranking), MAX_DEPTH_LOGGED) + + rlogger.log(qid, ranking, is_ranked=True, print_positions=[1, 2] if query_index % 100 == 0 else []) + + print('\n\n') + print(ranking_logger.filename) + print_message('#> Done.\n') + + thread.join() diff --git a/colbert/ranking/batch_retrieval.py b/colbert/ranking/batch_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..1497f8022499258eb7aa6b43cc63ef08bceb2fd5 --- /dev/null +++ b/colbert/ranking/batch_retrieval.py @@ -0,0 +1,50 @@ +import os +import time +import faiss +import random +import torch + +from colbert.utils.runs import Run +from multiprocessing import Pool +from colbert.modeling.inference import ModelInference +from colbert.evaluation.ranking_logger import RankingLogger + +from colbert.utils.utils import print_message, batch +from colbert.ranking.faiss_index import FaissIndex + + +def batch_retrieve(args): + assert args.retrieve_only, "TODO: Combine batch (multi-query) retrieval with batch re-ranking" + + faiss_index = FaissIndex(args.index_path, args.faiss_index_path, args.nprobe, args.part_range) + inference = ModelInference(args.colbert, amp=args.amp) + + ranking_logger = RankingLogger(Run.path, qrels=None) + + with ranking_logger.context('unordered.tsv', also_save_annotations=False) as rlogger: + queries = args.queries + qids_in_order = list(queries.keys()) + + for qoffset, qbatch in batch(qids_in_order, 100_000, provide_offset=True): + qbatch_text = [queries[qid] for qid in qbatch] + + print_message(f"#> Embedding {len(qbatch_text)} queries in parallel...") + Q = inference.queryFromText(qbatch_text, bsize=512) + + print_message("#> Starting batch retrieval...") + all_pids = faiss_index.retrieve(args.faiss_depth, Q, verbose=True) + + # Log the PIDs with rank -1 for all + for query_idx, (qid, ranking) in enumerate(zip(qbatch, all_pids)): + query_idx = qoffset + query_idx + + if query_idx % 1000 == 0: + print_message(f"#> Logging query #{query_idx} (qid {qid}) now...") + + ranking = [(None, pid, None) for pid in ranking] + rlogger.log(qid, ranking, is_ranked=False) + + print('\n\n') + print(ranking_logger.filename) + print("#> Done.") + print('\n\n') diff --git a/colbert/ranking/faiss_index.py b/colbert/ranking/faiss_index.py new file mode 100644 index 
0000000000000000000000000000000000000000..cb53be5f721d980ecaf427de3ce65c3775b506bc --- /dev/null +++ b/colbert/ranking/faiss_index.py @@ -0,0 +1,122 @@ +import os +import time +import faiss +import random +import torch + +from multiprocessing import Pool +from colbert.modeling.inference import ModelInference + +from colbert.utils.utils import print_message, flatten, batch +from colbert.indexing.loaders import load_doclens + + +class FaissIndex(): + def __init__(self, index_path, faiss_index_path, nprobe, part_range=None): + print_message("#> Loading the FAISS index from", faiss_index_path, "..") + + faiss_part_range = os.path.basename(faiss_index_path).split('.')[-2].split('-') + + if len(faiss_part_range) == 2: + faiss_part_range = range(*map(int, faiss_part_range)) + assert part_range[0] in faiss_part_range, (part_range, faiss_part_range) + assert part_range[-1] in faiss_part_range, (part_range, faiss_part_range) + else: + faiss_part_range = None + + self.part_range = part_range + self.faiss_part_range = faiss_part_range + + self.faiss_index = faiss.read_index(faiss_index_path) + self.faiss_index.nprobe = nprobe + + print_message("#> Building the emb2pid mapping..") + all_doclens = load_doclens(index_path, flatten=False) + + pid_offset = 0 + if faiss_part_range is not None: + print(f"#> Restricting all_doclens to the range {faiss_part_range}.") + pid_offset = len(flatten(all_doclens[:faiss_part_range.start])) + all_doclens = all_doclens[faiss_part_range.start:faiss_part_range.stop] + + self.relative_range = None + if self.part_range is not None: + start = self.faiss_part_range.start if self.faiss_part_range is not None else 0 + a = len(flatten(all_doclens[:self.part_range.start - start])) + b = len(flatten(all_doclens[:self.part_range.stop - start])) + self.relative_range = range(a, b) + print(f"self.relative_range = {self.relative_range}") + + all_doclens = flatten(all_doclens) + + total_num_embeddings = sum(all_doclens) + self.emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int) + + offset_doclens = 0 + for pid, dlength in enumerate(all_doclens): + self.emb2pid[offset_doclens: offset_doclens + dlength] = pid_offset + pid + offset_doclens += dlength + + print_message("len(self.emb2pid) =", len(self.emb2pid)) + + self.parallel_pool = Pool(16) + + def retrieve(self, faiss_depth, Q, verbose=False): + embedding_ids = self.queries_to_embedding_ids(faiss_depth, Q, verbose=verbose) + pids = self.embedding_ids_to_pids(embedding_ids, verbose=verbose) + + if self.relative_range is not None: + pids = [[pid for pid in pids_ if pid in self.relative_range] for pids_ in pids] + + return pids + + def queries_to_embedding_ids(self, faiss_depth, Q, verbose=True): + # Flatten into a matrix for the faiss search. + num_queries, embeddings_per_query, dim = Q.size() + Q_faiss = Q.view(num_queries * embeddings_per_query, dim).cpu().contiguous() + + # Search in large batches with faiss. + print_message("#> Search in batches with faiss. 
\t\t", + f"Q.size() = {Q.size()}, Q_faiss.size() = {Q_faiss.size()}", + condition=verbose) + + embeddings_ids = [] + faiss_bsize = embeddings_per_query * 5000 + for offset in range(0, Q_faiss.size(0), faiss_bsize): + endpos = min(offset + faiss_bsize, Q_faiss.size(0)) + + print_message("#> Searching from {} to {}...".format(offset, endpos), condition=verbose) + + some_Q_faiss = Q_faiss[offset:endpos].float().numpy() + _, some_embedding_ids = self.faiss_index.search(some_Q_faiss, faiss_depth) + embeddings_ids.append(torch.from_numpy(some_embedding_ids)) + + embedding_ids = torch.cat(embeddings_ids) + + # Reshape to (number of queries, non-unique embedding IDs per query) + embedding_ids = embedding_ids.view(num_queries, embeddings_per_query * embedding_ids.size(1)) + + return embedding_ids + + def embedding_ids_to_pids(self, embedding_ids, verbose=True): + # Find unique PIDs per query. + print_message("#> Lookup the PIDs..", condition=verbose) + all_pids = self.emb2pid[embedding_ids] + + print_message(f"#> Converting to a list [shape = {all_pids.size()}]..", condition=verbose) + all_pids = all_pids.tolist() + + print_message("#> Removing duplicates (in parallel if large enough)..", condition=verbose) + + if len(all_pids) > 5000: + all_pids = list(self.parallel_pool.map(uniq, all_pids)) + else: + all_pids = list(map(uniq, all_pids)) + + print_message("#> Done with embedding_ids_to_pids().", condition=verbose) + + return all_pids + + +def uniq(l): + return list(set(l)) diff --git a/colbert/ranking/index_part.py b/colbert/ranking/index_part.py new file mode 100644 index 0000000000000000000000000000000000000000..189476ef7fb22dc197882c2c1bb62f60d4890e0c --- /dev/null +++ b/colbert/ranking/index_part.py @@ -0,0 +1,82 @@ +import os +import torch +import ujson + +from math import ceil +from itertools import accumulate +from colbert.utils.utils import print_message, dotdict, flatten + +from colbert.indexing.loaders import get_parts, load_doclens +from colbert.indexing.index_manager import load_index_part +from colbert.ranking.index_ranker import IndexRanker + + +class IndexPart(): + def __init__(self, directory, dim=128, part_range=None, verbose=True): + first_part, last_part = (0, None) if part_range is None else (part_range.start, part_range.stop) + + # Load parts metadata + all_parts, all_parts_paths, _ = get_parts(directory) + self.parts = all_parts[first_part:last_part] + self.parts_paths = all_parts_paths[first_part:last_part] + + # Load doclens metadata + all_doclens = load_doclens(directory, flatten=False) + + self.doc_offset = sum([len(part_doclens) for part_doclens in all_doclens[:first_part]]) + self.doc_endpos = sum([len(part_doclens) for part_doclens in all_doclens[:last_part]]) + self.pids_range = range(self.doc_offset, self.doc_endpos) + + self.parts_doclens = all_doclens[first_part:last_part] + self.doclens = flatten(self.parts_doclens) + self.num_embeddings = sum(self.doclens) + + self.tensor = self._load_parts(dim, verbose) + self.ranker = IndexRanker(self.tensor, self.doclens) + + def _load_parts(self, dim, verbose): + tensor = torch.zeros(self.num_embeddings + 512, dim, dtype=torch.float16) + + if verbose: + print_message("tensor.size() = ", tensor.size()) + + offset = 0 + for idx, filename in enumerate(self.parts_paths): + print_message("|> Loading", filename, "...", condition=verbose) + + endpos = offset + sum(self.parts_doclens[idx]) + part = load_index_part(filename, verbose=verbose) + + tensor[offset:endpos] = part + offset = endpos + + return tensor + + def 
pid_in_range(self, pid): + return pid in self.pids_range + + def rank(self, Q, pids): + """ + Rank a single batch of Q x pids (e.g., 1k--10k pairs). + """ + + assert Q.size(0) in [1, len(pids)], (Q.size(0), len(pids)) + assert all(pid in self.pids_range for pid in pids), self.pids_range + + pids_ = [pid - self.doc_offset for pid in pids] + scores = self.ranker.rank(Q, pids_) + + return scores + + def batch_rank(self, all_query_embeddings, query_indexes, pids, sorted_pids): + """ + Rank a large, fairly dense set of query--passage pairs (e.g., 1M+ pairs). + Higher overhead, much faster for large batches. + """ + + assert ((pids >= self.pids_range.start) & (pids < self.pids_range.stop)).sum() == pids.size(0) + + pids_ = pids - self.doc_offset + scores = self.ranker.batch_rank(all_query_embeddings, query_indexes, pids_, sorted_pids) + + return scores diff --git a/colbert/ranking/index_ranker.py b/colbert/ranking/index_ranker.py new file mode 100644 index 0000000000000000000000000000000000000000..e7b32293e3035e40a42dd9946bd5c556bc47e371 --- /dev/null +++ b/colbert/ranking/index_ranker.py @@ -0,0 +1,164 @@ +import os +import math +import torch +import ujson +import traceback + +from itertools import accumulate +from colbert.parameters import DEVICE +from colbert.utils.utils import print_message, dotdict, flatten + +BSIZE = 1 << 14 + + +class IndexRanker(): + def __init__(self, tensor, doclens): + self.tensor = tensor + self.doclens = doclens + + self.maxsim_dtype = torch.float32 + self.doclens_pfxsum = [0] + list(accumulate(self.doclens)) + + self.doclens = torch.tensor(self.doclens) + self.doclens_pfxsum = torch.tensor(self.doclens_pfxsum) + + self.dim = self.tensor.size(-1) + + self.strides = [torch_percentile(self.doclens, p) for p in [90]] + self.strides.append(self.doclens.max().item()) + self.strides = sorted(list(set(self.strides))) + + print_message(f"#> Using strides {self.strides}..") + + self.views = self._create_views(self.tensor) + self.buffers = self._create_buffers(BSIZE, self.tensor.dtype, {'cpu', 'cuda:0'}) + + def _create_views(self, tensor): + views = [] + + for stride in self.strides: + outdim = tensor.size(0) - stride + 1 + view = torch.as_strided(tensor, (outdim, stride, self.dim), (self.dim, self.dim, 1)) + views.append(view) + + return views + + def _create_buffers(self, max_bsize, dtype, devices): + buffers = {} + + for device in devices: + buffers[device] = [torch.zeros(max_bsize, stride, self.dim, dtype=dtype, + device=device, pin_memory=(device == 'cpu')) + for stride in self.strides] + + return buffers + + def rank(self, Q, pids, views=None, shift=0): + assert len(pids) > 0 + assert Q.size(0) in [1, len(pids)] + + Q = Q.contiguous().to(DEVICE).to(dtype=self.maxsim_dtype) + + views = self.views if views is None else views + VIEWS_DEVICE = views[0].device + + D_buffers = self.buffers[str(VIEWS_DEVICE)] + + raw_pids = pids if type(pids) is list else pids.tolist() + pids = torch.tensor(pids) if type(pids) is list else pids + + doclens, offsets = self.doclens[pids], self.doclens_pfxsum[pids] + + assignments = (doclens.unsqueeze(1) > torch.tensor(self.strides).unsqueeze(0) + 1e-6).sum(-1) + + one_to_n = torch.arange(len(raw_pids)) + output_pids, output_scores, output_permutation = [], [], [] + + for group_idx, stride in enumerate(self.strides): + locator = (assignments == group_idx) + + if locator.sum() < 1e-5: + continue + + group_pids, group_doclens, group_offsets = pids[locator], doclens[locator], offsets[locator] + group_Q = Q if Q.size(0) == 1 else Q[locator] + + 
group_offsets = group_offsets.to(VIEWS_DEVICE) - shift + group_offsets_uniq, group_offsets_expand = torch.unique_consecutive(group_offsets, return_inverse=True) + + D_size = group_offsets_uniq.size(0) + D = torch.index_select(views[group_idx], 0, group_offsets_uniq, out=D_buffers[group_idx][:D_size]) + D = D.to(DEVICE) + D = D[group_offsets_expand.to(DEVICE)].to(dtype=self.maxsim_dtype) + + mask = torch.arange(stride, device=DEVICE) + 1 + mask = mask.unsqueeze(0) <= group_doclens.to(DEVICE).unsqueeze(-1) + + scores = (D @ group_Q) * mask.unsqueeze(-1) + scores = scores.max(1).values.sum(-1).cpu() + + output_pids.append(group_pids) + output_scores.append(scores) + output_permutation.append(one_to_n[locator]) + + output_permutation = torch.cat(output_permutation).sort().indices + output_pids = torch.cat(output_pids)[output_permutation].tolist() + output_scores = torch.cat(output_scores)[output_permutation].tolist() + + assert len(raw_pids) == len(output_pids) + assert len(raw_pids) == len(output_scores) + assert raw_pids == output_pids + + return output_scores + + def batch_rank(self, all_query_embeddings, all_query_indexes, all_pids, sorted_pids): + assert sorted_pids is True + + ###### + + scores = [] + range_start, range_end = 0, 0 + + for pid_offset in range(0, len(self.doclens), 50_000): + pid_endpos = min(pid_offset + 50_000, len(self.doclens)) + + range_start = range_start + (all_pids[range_start:] < pid_offset).sum() + range_end = range_end + (all_pids[range_end:] < pid_endpos).sum() + + pids = all_pids[range_start:range_end] + query_indexes = all_query_indexes[range_start:range_end] + + print_message(f"###--> Got {len(pids)} query--passage pairs in this sub-range {(pid_offset, pid_endpos)}.") + + if len(pids) == 0: + continue + + print_message(f"###--> Ranking in batches the pairs #{range_start} through #{range_end} in this sub-range.") + + tensor_offset = self.doclens_pfxsum[pid_offset].item() + tensor_endpos = self.doclens_pfxsum[pid_endpos].item() + 512 + + collection = self.tensor[tensor_offset:tensor_endpos].to(DEVICE) + views = self._create_views(collection) + + print_message(f"#> Ranking in batches of {BSIZE} query--passage pairs...") + + for batch_idx, offset in enumerate(range(0, len(pids), BSIZE)): + if batch_idx % 100 == 0: + print_message("#> Processing batch #{}..".format(batch_idx)) + + endpos = offset + BSIZE + batch_query_index, batch_pids = query_indexes[offset:endpos], pids[offset:endpos] + + Q = all_query_embeddings[batch_query_index] + + scores.extend(self.rank(Q, batch_pids, views, shift=tensor_offset)) + + return scores + + +def torch_percentile(tensor, p): + assert p in range(1, 100+1) + assert tensor.dim() == 1 + + return tensor.kthvalue(int(p * tensor.size(0) / 100.0)).values.item() diff --git a/colbert/ranking/rankers.py b/colbert/ranking/rankers.py new file mode 100644 index 0000000000000000000000000000000000000000..a6cb3cf8c7313566d527ad91d967643d4a3e7b9e --- /dev/null +++ b/colbert/ranking/rankers.py @@ -0,0 +1,43 @@ +import torch + +from functools import partial + +from colbert.ranking.index_part import IndexPart +from colbert.ranking.faiss_index import FaissIndex +from colbert.utils.utils import flatten, zipstar + + +class Ranker(): + def __init__(self, args, inference, faiss_depth=1024): + self.inference = inference + self.faiss_depth = faiss_depth + + if faiss_depth is not None: + self.faiss_index = FaissIndex(args.index_path, args.faiss_index_path, args.nprobe, part_range=args.part_range) + self.retrieve = partial(self.faiss_index.retrieve, 
self.faiss_depth) + + self.index = IndexPart(args.index_path, dim=inference.colbert.dim, part_range=args.part_range, verbose=True) + + def encode(self, queries): + assert type(queries) in [list, tuple], type(queries) + + Q = self.inference.queryFromText(queries, bsize=512 if len(queries) > 512 else None) + + return Q + + def rank(self, Q, pids=None): + pids = self.retrieve(Q, verbose=False)[0] if pids is None else pids + + assert type(pids) in [list, tuple], type(pids) + assert Q.size(0) == 1, (len(pids), Q.size()) + assert all(type(pid) is int for pid in pids) + + scores = [] + if len(pids) > 0: + Q = Q.permute(0, 2, 1) + scores = self.index.rank(Q, pids) + + scores_sorter = torch.tensor(scores).sort(descending=True) + pids, scores = torch.tensor(pids)[scores_sorter.indices].tolist(), scores_sorter.values.tolist() + + return pids, scores diff --git a/colbert/ranking/reranking.py b/colbert/ranking/reranking.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca4a12fefd516428c670ebfb60014e099b86b5c --- /dev/null +++ b/colbert/ranking/reranking.py @@ -0,0 +1,61 @@ +import os +import time +import faiss +import random +import torch + +from colbert.utils.runs import Run +from multiprocessing import Pool +from colbert.modeling.inference import ModelInference +from colbert.evaluation.ranking_logger import RankingLogger + +from colbert.utils.utils import print_message, batch +from colbert.ranking.rankers import Ranker + + +def rerank(args): + inference = ModelInference(args.colbert, amp=args.amp) + ranker = Ranker(args, inference, faiss_depth=None) + + ranking_logger = RankingLogger(Run.path, qrels=None) + milliseconds = 0 + + with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger: + queries = args.queries + qids_in_order = list(queries.keys()) + + for qoffset, qbatch in batch(qids_in_order, 100, provide_offset=True): + qbatch_text = [queries[qid] for qid in qbatch] + qbatch_pids = [args.topK_pids[qid] for qid in qbatch] + + rankings = [] + + for query_idx, (q, pids) in enumerate(zip(qbatch_text, qbatch_pids)): + torch.cuda.synchronize('cuda:0') + s = time.time() + + Q = ranker.encode([q]) + pids, scores = ranker.rank(Q, pids=pids) + + torch.cuda.synchronize() + milliseconds += (time.time() - s) * 1000.0 + + if len(pids): + print(qoffset+query_idx, q, len(scores), len(pids), scores[0], pids[0], + milliseconds / (qoffset+query_idx+1), 'ms') + + rankings.append(zip(pids, scores)) + + for query_idx, (qid, ranking) in enumerate(zip(qbatch, rankings)): + query_idx = qoffset + query_idx + + if query_idx % 100 == 0: + print_message(f"#> Logging query #{query_idx} (qid {qid}) now...") + + ranking = [(score, pid, None) for pid, score in ranking] + rlogger.log(qid, ranking, is_ranked=True) + + print('\n\n') + print(ranking_logger.filename) + print("#> Done.") + print('\n\n') diff --git a/colbert/ranking/retrieval.py b/colbert/ranking/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3ec6bdc5e682502b4d1eb68ce355db7a409795 --- /dev/null +++ b/colbert/ranking/retrieval.py @@ -0,0 +1,61 @@ +import os +import time +import faiss +import random +import torch +import itertools + +from colbert.utils.runs import Run +from multiprocessing import Pool +from colbert.modeling.inference import ModelInference +from colbert.evaluation.ranking_logger import RankingLogger + +from colbert.utils.utils import print_message, batch +from colbert.ranking.rankers import Ranker + + +def retrieve(args): + inference = ModelInference(args.colbert, 
amp=args.amp) + ranker = Ranker(args, inference, faiss_depth=args.faiss_depth) + + ranking_logger = RankingLogger(Run.path, qrels=None) + milliseconds = 0 + + with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger: + queries = args.queries + qids_in_order = list(queries.keys()) + + for qoffset, qbatch in batch(qids_in_order, 100, provide_offset=True): + qbatch_text = [queries[qid] for qid in qbatch] + + rankings = [] + + for query_idx, q in enumerate(qbatch_text): + torch.cuda.synchronize('cuda:0') + s = time.time() + + Q = ranker.encode([q]) + pids, scores = ranker.rank(Q) + + torch.cuda.synchronize() + milliseconds += (time.time() - s) * 1000.0 + + if len(pids): + print(qoffset+query_idx, q, len(scores), len(pids), scores[0], pids[0], + milliseconds / (qoffset+query_idx+1), 'ms') + + rankings.append(zip(pids, scores)) + + for query_idx, (qid, ranking) in enumerate(zip(qbatch, rankings)): + query_idx = qoffset + query_idx + + if query_idx % 100 == 0: + print_message(f"#> Logging query #{query_idx} (qid {qid}) now...") + + ranking = [(score, pid, None) for pid, score in itertools.islice(ranking, args.depth)] + rlogger.log(qid, ranking, is_ranked=True) + + print('\n\n') + print(ranking_logger.filename) + print("#> Done.") + print('\n\n') diff --git a/colbert/rerank.py b/colbert/rerank.py new file mode 100644 index 0000000000000000000000000000000000000000..b5713ec0c00e3da958d7244323e411bfba86c6df --- /dev/null +++ b/colbert/rerank.py @@ -0,0 +1,50 @@ +import os +import random + +from colbert.utils.parser import Arguments +from colbert.utils.runs import Run + +from colbert.evaluation.loaders import load_colbert, load_qrels, load_queries, load_topK_pids +from colbert.ranking.reranking import rerank +from colbert.ranking.batch_reranking import batch_rerank + + +def main(): + random.seed(12345) + + parser = Arguments(description='Re-ranking over a ColBERT index') + + parser.add_model_parameters() + parser.add_model_inference_parameters() + parser.add_reranking_input() + parser.add_index_use_input() + + parser.add_argument('--step', dest='step', default=1, type=int) + parser.add_argument('--part-range', dest='part_range', default=None, type=str) + parser.add_argument('--log-scores', dest='log_scores', default=False, action='store_true') + parser.add_argument('--batch', dest='batch', default=False, action='store_true') + parser.add_argument('--depth', dest='depth', default=1000, type=int) + + args = parser.parse() + + if args.part_range: + part_offset, part_endpos = map(int, args.part_range.split('..')) + args.part_range = range(part_offset, part_endpos) + + with Run.context(): + args.colbert, args.checkpoint = load_colbert(args) + + args.queries = load_queries(args.queries) + args.qrels = load_qrels(args.qrels) + args.topK_pids, args.qrels = load_topK_pids(args.topK, qrels=args.qrels) + + args.index_path = os.path.join(args.index_root, args.index_name) + + if args.batch: + batch_rerank(args) + else: + rerank(args) + + +if __name__ == "__main__": + main() diff --git a/colbert/retrieve.py b/colbert/retrieve.py new file mode 100644 index 0000000000000000000000000000000000000000..4bc53f4d80563ab750ab573f3152846481397462 --- /dev/null +++ b/colbert/retrieve.py @@ -0,0 +1,56 @@ +import os +import random + +from colbert.utils.parser import Arguments +from colbert.utils.runs import Run + +from colbert.evaluation.loaders import load_colbert, load_qrels, load_queries +from colbert.indexing.faiss import get_faiss_index_name +from colbert.ranking.retrieval import retrieve +from 
colbert.ranking.batch_retrieval import batch_retrieve + + +def main(): + random.seed(12345) + + parser = Arguments(description='End-to-end retrieval and ranking with ColBERT.') + + parser.add_model_parameters() + parser.add_model_inference_parameters() + parser.add_ranking_input() + parser.add_retrieval_input() + + parser.add_argument('--faiss_name', dest='faiss_name', default=None, type=str) + parser.add_argument('--faiss_depth', dest='faiss_depth', default=1024, type=int) + parser.add_argument('--part-range', dest='part_range', default=None, type=str) + parser.add_argument('--batch', dest='batch', default=False, action='store_true') + parser.add_argument('--depth', dest='depth', default=1000, type=int) + + args = parser.parse() + + args.depth = args.depth if args.depth > 0 else None + + if args.part_range: + part_offset, part_endpos = map(int, args.part_range.split('..')) + args.part_range = range(part_offset, part_endpos) + + with Run.context(): + args.colbert, args.checkpoint = load_colbert(args) + args.qrels = load_qrels(args.qrels) + args.queries = load_queries(args.queries) + + args.index_path = os.path.join(args.index_root, args.index_name) + + if args.faiss_name is not None: + args.faiss_index_path = os.path.join(args.index_path, args.faiss_name) + else: + args.faiss_index_path = os.path.join(args.index_path, get_faiss_index_name(args)) + + if args.batch: + batch_retrieve(args) + else: + retrieve(args) + + +if __name__ == "__main__": + main() diff --git a/colbert/test.py b/colbert/test.py new file mode 100644 index 0000000000000000000000000000000000000000..d4069472beb5b7d34cee84ce13fb56d9b0dcdd0d --- /dev/null +++ b/colbert/test.py @@ -0,0 +1,49 @@ +import os +import random + +from colbert.utils.parser import Arguments +from colbert.utils.runs import Run + +from colbert.evaluation.loaders import load_colbert, load_topK, load_qrels +from colbert.evaluation.loaders import load_queries, load_topK_pids, load_collection +from colbert.evaluation.ranking import evaluate +from colbert.evaluation.metrics import evaluate_recall + + +def main(): + random.seed(12345) + + parser = Arguments(description='Exhaustive (slow, not index-based) evaluation of re-ranking with ColBERT.') + + parser.add_model_parameters() + parser.add_model_inference_parameters() + parser.add_reranking_input() + + parser.add_argument('--depth', dest='depth', required=False, default=None, type=int) + + args = parser.parse() + + with Run.context(): + args.colbert, args.checkpoint = load_colbert(args) + args.qrels = load_qrels(args.qrels) + + if args.collection or args.queries: + assert args.collection and args.queries + + args.queries = load_queries(args.queries) + args.collection = load_collection(args.collection) + args.topK_pids, args.qrels = load_topK_pids(args.topK, args.qrels) + + else: + args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK) + + assert (not args.shortcircuit) or args.qrels, \ + "Short-circuiting (i.e., applying minimal computation to queries with no positives in the re-ranked set) " \ + "can only be applied if qrels is provided." 
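+
+        # First report the recall of the given top-K candidates against the qrels (an
+        # upper bound for re-ranking, which only reorders candidates), then evaluate
+        # the actual ColBERT re-ranking.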
+ + evaluate_recall(args.qrels, args.queries, args.topK_pids) + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/colbert/train.py b/colbert/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e6611a91b38806a4c2a90da49ce5fcd2808bd1da --- /dev/null +++ b/colbert/train.py @@ -0,0 +1,34 @@ +import os +import random +import torch +import copy + +import colbert.utils.distributed as distributed + +from colbert.utils.parser import Arguments +from colbert.utils.runs import Run +from colbert.training.training import train + + +def main(): + parser = Arguments(description='Training ColBERT with triples.') + + parser.add_model_parameters() + parser.add_model_training_parameters() + parser.add_training_input() + + args = parser.parse() + + assert args.bsize % args.accumsteps == 0, ((args.bsize, args.accumsteps), + "The batch size must be divisible by the number of gradient accumulation steps.") + assert args.query_maxlen <= 512 + assert args.doc_maxlen <= 512 + + args.lazy = args.collection is not None + + with Run.context(consider_failed_if_interrupted=False): + train(args) + + +if __name__ == "__main__": + main() diff --git a/colbert/training/__init__.py b/colbert/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colbert/training/__pycache__/__init__.cpython-37.pyc b/colbert/training/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3334ce32b3f6eaecd8ee190603a10a44e3b05a7 Binary files /dev/null and b/colbert/training/__pycache__/__init__.cpython-37.pyc differ diff --git a/colbert/training/__pycache__/eager_batcher.cpython-37.pyc b/colbert/training/__pycache__/eager_batcher.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e1ca7ab65ca1dfe3a37ce35981e4fb644fc8961 Binary files /dev/null and b/colbert/training/__pycache__/eager_batcher.cpython-37.pyc differ diff --git a/colbert/training/__pycache__/lazy_batcher.cpython-37.pyc b/colbert/training/__pycache__/lazy_batcher.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cd503f27ace0dec32c0203c5ca88fdd2af4a386 Binary files /dev/null and b/colbert/training/__pycache__/lazy_batcher.cpython-37.pyc differ diff --git a/colbert/training/__pycache__/training.cpython-37.pyc b/colbert/training/__pycache__/training.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de5fe6af6e0e8be54bf9be1f88b22fe4d92cf40f Binary files /dev/null and b/colbert/training/__pycache__/training.cpython-37.pyc differ diff --git a/colbert/training/__pycache__/utils.cpython-37.pyc b/colbert/training/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a49c87423e5ca21149916e14fcc0ca7123ad7cb Binary files /dev/null and b/colbert/training/__pycache__/utils.cpython-37.pyc differ diff --git a/colbert/training/eager_batcher.py b/colbert/training/eager_batcher.py new file mode 100644 index 0000000000000000000000000000000000000000..a604174962a6713db03baace9b29a8139afd2b77 --- /dev/null +++ b/colbert/training/eager_batcher.py @@ -0,0 +1,62 @@ +import os +import ujson + +from functools import partial +from colbert.utils.utils import print_message +from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples + +from colbert.utils.runs import Run + + +class EagerBatcher(): + def __init__(self, args, rank=0, nranks=1): + self.rank, 
self.nranks = rank, nranks + self.bsize, self.accumsteps = args.bsize, args.accumsteps + + self.query_tokenizer = QueryTokenizer(args.query_maxlen) + self.doc_tokenizer = DocTokenizer(args.doc_maxlen) + self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer) + + self.triples_path = args.triples + self._reset_triples() + + def _reset_triples(self): + self.reader = open(self.triples_path, mode='r', encoding="utf-8") + self.position = 0 + + def __iter__(self): + return self + + def __next__(self): + queries, positives, negatives = [], [], [] + + for line_idx, line in zip(range(self.bsize * self.nranks), self.reader): + if (self.position + line_idx) % self.nranks != self.rank: + continue + + query, pos, neg = line.strip().split('\t') + + queries.append(query) + positives.append(pos) + negatives.append(neg) + + self.position += line_idx + 1 + + if len(queries) < self.bsize: + raise StopIteration + + return self.collate(queries, positives, negatives) + + def collate(self, queries, positives, negatives): + assert len(queries) == len(positives) == len(negatives) == self.bsize + + return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps) + + def skip_to_batch(self, batch_idx, intended_batch_size): + self._reset_triples() + + Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.') + + _ = [self.reader.readline() for _ in range(batch_idx * intended_batch_size)] + + return None diff --git a/colbert/training/lazy_batcher.py b/colbert/training/lazy_batcher.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf28eb980bd0b6ccbeb2db162e0f5136d6205fd --- /dev/null +++ b/colbert/training/lazy_batcher.py @@ -0,0 +1,103 @@ +import os +import ujson + +from functools import partial +from colbert.utils.utils import print_message +from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples + +from colbert.utils.runs import Run + + +class LazyBatcher(): + def __init__(self, args, rank=0, nranks=1): + self.bsize, self.accumsteps = args.bsize, args.accumsteps + + self.query_tokenizer = QueryTokenizer(args.query_maxlen) + self.doc_tokenizer = DocTokenizer(args.doc_maxlen) + self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer) + self.position = 0 + + self.triples = self._load_triples(args.triples, rank, nranks) + self.queries = self._load_queries(args.queries) + self.collection = self._load_collection(args.collection) + + def _load_triples(self, path, rank, nranks): + """ + NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling. + In particular, each subset is perfectly represented in every batch! However, since we never + repeat passes over the data, we never repeat any particular triple, and the split across + nodes is random (since the underlying file is pre-shuffled), there's no concern here. 
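+        For example, with nranks=4, rank 1 loads the triples at line indices 1, 5, 9, and so on.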
+ """ + print_message("#> Loading triples...") + + triples = [] + + with open(path) as f: + for line_idx, line in enumerate(f): + if line_idx % nranks == rank: + qid, pos, neg = ujson.loads(line) + triples.append((qid, pos, neg)) + + return triples + + def _load_queries(self, path): + print_message("#> Loading queries...") + + queries = {} + + with open(path) as f: + for line in f: + qid, query = line.strip().split('\t') + qid = int(qid) + queries[qid] = query + + return queries + + def _load_collection(self, path): + print_message("#> Loading collection...") + + collection = [] + + with open(path) as f: + for line_idx, line in enumerate(f): + pid, passage, title, *_ = line.strip().split('\t') + assert pid == 'id' or int(pid) == line_idx + + passage = title + ' | ' + passage + collection.append(passage) + + return collection + + def __iter__(self): + return self + + def __len__(self): + return len(self.triples) + + def __next__(self): + offset, endpos = self.position, min(self.position + self.bsize, len(self.triples)) + self.position = endpos + + if offset + self.bsize > len(self.triples): + raise StopIteration + + queries, positives, negatives = [], [], [] + + for position in range(offset, endpos): + query, pos, neg = self.triples[position] + query, pos, neg = self.queries[query], self.collection[pos], self.collection[neg] + + queries.append(query) + positives.append(pos) + negatives.append(neg) + + return self.collate(queries, positives, negatives) + + def collate(self, queries, positives, negatives): + assert len(queries) == len(positives) == len(negatives) == self.bsize + + return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps) + + def skip_to_batch(self, batch_idx, intended_batch_size): + Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.') + self.position = intended_batch_size * batch_idx diff --git a/colbert/training/training.py b/colbert/training/training.py new file mode 100644 index 0000000000000000000000000000000000000000..cabaa7845d1985f494daa23e4f0a784755a89aef --- /dev/null +++ b/colbert/training/training.py @@ -0,0 +1,123 @@ +import os +import random +import time +import torch +import torch.nn as nn +import numpy as np + +from transformers import AdamW +from colbert.utils.runs import Run +from colbert.utils.amp import MixedPrecisionManager + +from colbert.training.lazy_batcher import LazyBatcher +from colbert.training.eager_batcher import EagerBatcher +from colbert.parameters import DEVICE + +from colbert.modeling.colbert import ColBERT +from colbert.utils.utils import print_message +from colbert.training.utils import print_progress, manage_checkpoints + + +def train(args): + random.seed(12345) + np.random.seed(12345) + torch.manual_seed(12345) + if args.distributed: + torch.cuda.manual_seed_all(12345) + + if args.distributed: + assert args.bsize % args.nranks == 0, (args.bsize, args.nranks) + assert args.accumsteps == 1 + args.bsize = args.bsize // args.nranks + + print("Using args.bsize =", args.bsize, "(per process) and args.accumsteps =", args.accumsteps) + + if args.lazy: + reader = LazyBatcher(args, (0 if args.rank == -1 else args.rank), args.nranks) + else: + reader = EagerBatcher(args, (0 if args.rank == -1 else args.rank), args.nranks) + + if args.rank not in [-1, 0]: + torch.distributed.barrier() + + colbert = ColBERT.from_pretrained('bert-base-multilingual-uncased', + query_maxlen=args.query_maxlen, + doc_maxlen=args.doc_maxlen, + dim=args.dim, + 
similarity_metric=args.similarity, + mask_punctuation=args.mask_punctuation) + + if args.checkpoint is not None: + assert args.resume_optimizer is False, "TODO: This would mean reload optimizer too." + print_message(f"#> Starting from checkpoint {args.checkpoint} -- but NOT the optimizer!") + + checkpoint = torch.load(args.checkpoint, map_location='cpu') + + try: + colbert.load_state_dict(checkpoint['model_state_dict']) + except: + print_message("[WARNING] Loading checkpoint with strict=False") + colbert.load_state_dict(checkpoint['model_state_dict'], strict=False) + + if args.rank == 0: + torch.distributed.barrier() + + colbert = colbert.to(DEVICE) + colbert.train() + + if args.distributed: + colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[args.rank], + output_device=args.rank, + find_unused_parameters=True) + + optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=args.lr, eps=1e-8) + optimizer.zero_grad() + + amp = MixedPrecisionManager(args.amp) + criterion = nn.CrossEntropyLoss() + labels = torch.zeros(args.bsize, dtype=torch.long, device=DEVICE) + + start_time = time.time() + train_loss = 0.0 + + start_batch_idx = 0 + + if args.resume: + assert args.checkpoint is not None + start_batch_idx = checkpoint['batch'] + + reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize']) + + for batch_idx, BatchSteps in zip(range(start_batch_idx, args.maxsteps), reader): + this_batch_loss = 0.0 + + for queries, passages in BatchSteps: + with amp.context(): + scores = colbert(queries, passages).view(2, -1).permute(1, 0) + loss = criterion(scores, labels[:scores.size(0)]) + loss = loss / args.accumsteps + + if args.rank < 1: + print_progress(scores) + + amp.backward(loss) + + train_loss += loss.item() + this_batch_loss += loss.item() + + amp.step(colbert, optimizer) + + if args.rank < 1: + avg_loss = train_loss / (batch_idx+1) + + num_examples_seen = (batch_idx - start_batch_idx) * args.bsize * args.nranks + elapsed = float(time.time() - start_time) + + log_to_mlflow = (batch_idx % 20 == 0) + Run.log_metric('train/avg_loss', avg_loss, step=batch_idx, log_to_mlflow=log_to_mlflow) + Run.log_metric('train/batch_loss', this_batch_loss, step=batch_idx, log_to_mlflow=log_to_mlflow) + Run.log_metric('train/examples', num_examples_seen, step=batch_idx, log_to_mlflow=log_to_mlflow) + Run.log_metric('train/throughput', num_examples_seen / elapsed, step=batch_idx, log_to_mlflow=log_to_mlflow) + + print_message(batch_idx, avg_loss) + manage_checkpoints(args, colbert, optimizer, batch_idx+1) diff --git a/colbert/training/utils.py b/colbert/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..61fa1990a0d9ee7d21c76a2960f32c46c4456646 --- /dev/null +++ b/colbert/training/utils.py @@ -0,0 +1,28 @@ +import os +import torch + +from colbert.utils.runs import Run +from colbert.utils.utils import print_message, save_checkpoint +from colbert.parameters import SAVED_CHECKPOINTS + + +def print_progress(scores): + positive_avg, negative_avg = round(scores[:, 0].mean().item(), 2), round(scores[:, 1].mean().item(), 2) + print("#>>> ", positive_avg, negative_avg, '\t\t|\t\t', positive_avg - negative_avg) + + +def manage_checkpoints(args, colbert, optimizer, batch_idx): + arguments = args.input_arguments.__dict__ + + path = os.path.join(Run.path, 'checkpoints') + + if not os.path.exists(path): + os.mkdir(path) + + if batch_idx % 2000 == 0: + name = os.path.join(path, "colbert.dnn") + save_checkpoint(name, 0, batch_idx, colbert, 
optimizer, arguments) + + if batch_idx in SAVED_CHECKPOINTS: + name = os.path.join(path, "colbert-{}.dnn".format(batch_idx)) + save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments) diff --git a/colbert/utils/__init__.py b/colbert/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/colbert/utils/__pycache__/__init__.cpython-37.pyc b/colbert/utils/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b085426596af40e6b96250095797186102d1a950 Binary files /dev/null and b/colbert/utils/__pycache__/__init__.cpython-37.pyc differ diff --git a/colbert/utils/__pycache__/amp.cpython-37.pyc b/colbert/utils/__pycache__/amp.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a61988a26ecf2f492d92da50ad1ce699f8f47aa Binary files /dev/null and b/colbert/utils/__pycache__/amp.cpython-37.pyc differ diff --git a/colbert/utils/__pycache__/distributed.cpython-37.pyc b/colbert/utils/__pycache__/distributed.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91728700abe17641c9311bc9315fa0b2e36b6b45 Binary files /dev/null and b/colbert/utils/__pycache__/distributed.cpython-37.pyc differ diff --git a/colbert/utils/__pycache__/logging.cpython-37.pyc b/colbert/utils/__pycache__/logging.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f07afce45b6adc6c124244d1574e6169446813d Binary files /dev/null and b/colbert/utils/__pycache__/logging.cpython-37.pyc differ diff --git a/colbert/utils/__pycache__/parser.cpython-37.pyc b/colbert/utils/__pycache__/parser.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cd9724fd03f54b9888e4270f1a45686a3725728 Binary files /dev/null and b/colbert/utils/__pycache__/parser.cpython-37.pyc differ diff --git a/colbert/utils/__pycache__/runs.cpython-37.pyc b/colbert/utils/__pycache__/runs.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94b92ca9cff96e3937da04644fb1f10b25ec1ffb Binary files /dev/null and b/colbert/utils/__pycache__/runs.cpython-37.pyc differ diff --git a/colbert/utils/__pycache__/utils.cpython-37.pyc b/colbert/utils/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f86ed34b456797cd68a51773d6b40b7465c1762a Binary files /dev/null and b/colbert/utils/__pycache__/utils.cpython-37.pyc differ diff --git a/colbert/utils/amp.py b/colbert/utils/amp.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2a20f0bc3fe6f2f85594ff17ae1c1626390581 --- /dev/null +++ b/colbert/utils/amp.py @@ -0,0 +1,38 @@ +import torch + +from contextlib import contextmanager +from colbert.utils.utils import NullContextManager + +PyTorch_over_1_6 = float('.'.join(torch.__version__.split('.')[0:2])) >= 1.6 + + +class MixedPrecisionManager(): + def __init__(self, activated): + assert (not activated) or PyTorch_over_1_6, "Cannot use AMP for PyTorch version < 1.6" + + self.activated = activated + + if self.activated: + self.scaler = torch.cuda.amp.GradScaler() + + def context(self): + return torch.cuda.amp.autocast() if self.activated else NullContextManager() + + def backward(self, loss): + if self.activated: + self.scaler.scale(loss).backward() + else: + loss.backward() + + def step(self, colbert, optimizer): + if self.activated: + self.scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0) + + 
self.scaler.step(optimizer) + self.scaler.update() + optimizer.zero_grad() + else: + torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0) + optimizer.step() + optimizer.zero_grad() diff --git a/colbert/utils/distributed.py b/colbert/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d431e9752ff1a3f89a8042c4abd39cff0efcae --- /dev/null +++ b/colbert/utils/distributed.py @@ -0,0 +1,27 @@ +import os +import random +import torch +import numpy as np + + +def init(rank): + nranks = 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) + nranks = max(1, nranks) + # nranks = -1 + # is_distributed = nranks > 0 + is_distributed = False + + if rank == 0: + print('nranks =', nranks, '\t num_gpus =', torch.cuda.device_count()) + + if is_distributed: + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + torch.distributed.init_process_group(backend='nccl', init_method='env://') + + return nranks, is_distributed + + +def barrier(rank): + if rank >= 0: + torch.distributed.barrier() diff --git a/colbert/utils/logging.py b/colbert/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..ff31184d73da9b7224cff350a235c1e5558ca1db --- /dev/null +++ b/colbert/utils/logging.py @@ -0,0 +1,99 @@ +import os +import sys +import ujson +import mlflow +import traceback + +from torch.utils.tensorboard import SummaryWriter +from colbert.utils.utils import print_message, create_directory + + +class Logger(): + def __init__(self, rank, run): + self.rank = rank + self.is_main = self.rank in [-1, 0] + self.run = run + self.logs_path = os.path.join(self.run.path, "logs/") + + if self.is_main: + self._init_mlflow() + self.initialized_tensorboard = False + create_directory(self.logs_path) + + def _init_mlflow(self): + mlflow.set_tracking_uri('file://' + os.path.join(self.run.experiments_root, "logs/mlruns/")) + mlflow.set_experiment('/'.join([self.run.experiment, self.run.script])) + + mlflow.set_tag('experiment', self.run.experiment) + mlflow.set_tag('name', self.run.name) + mlflow.set_tag('path', self.run.path) + + def _init_tensorboard(self): + root = os.path.join(self.run.experiments_root, "logs/tensorboard/") + logdir = '__'.join([self.run.experiment, self.run.script, self.run.name]) + logdir = os.path.join(root, logdir) + + self.writer = SummaryWriter(log_dir=logdir) + self.initialized_tensorboard = True + + def _log_exception(self, etype, value, tb): + if not self.is_main: + return + + output_path = os.path.join(self.logs_path, 'exception.txt') + trace = ''.join(traceback.format_exception(etype, value, tb)) + '\n' + print_message(trace, '\n\n') + + self.log_new_artifact(output_path, trace) + + def _log_all_artifacts(self): + if not self.is_main: + return + + mlflow.log_artifacts(self.logs_path) + + def _log_args(self, args): + if not self.is_main: + return + + for key in vars(args): + value = getattr(args, key) + if type(value) in [int, float, str, bool]: + mlflow.log_param(key, value) + + with open(os.path.join(self.logs_path, 'args.json'), 'w') as output_metadata: + ujson.dump(args.input_arguments.__dict__, output_metadata, indent=4) + output_metadata.write('\n') + + with open(os.path.join(self.logs_path, 'args.txt'), 'w') as output_metadata: + output_metadata.write(' '.join(sys.argv) + '\n') + + def log_metric(self, name, value, step, log_to_mlflow=True): + if not self.is_main: + return + + if not self.initialized_tensorboard: + self._init_tensorboard() + + if log_to_mlflow: + mlflow.log_metric(name, value, 
step=step) + self.writer.add_scalar(name, value, step) + + def log_new_artifact(self, path, content): + with open(path, 'w') as f: + f.write(content) + + mlflow.log_artifact(path) + + def warn(self, *args): + msg = print_message('[WARNING]', '\t', *args) + + with open(os.path.join(self.logs_path, 'warnings.txt'), 'a') as output_metadata: + output_metadata.write(msg + '\n\n\n') + + def info_all(self, *args): + print_message('[' + str(self.rank) + ']', '\t', *args) + + def info(self, *args): + if self.is_main: + print_message(*args) diff --git a/colbert/utils/parser.py b/colbert/utils/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..d107b18ec068ab18f2aa253235d86977395ab8fe --- /dev/null +++ b/colbert/utils/parser.py @@ -0,0 +1,114 @@ +import os +import copy +import faiss + +from argparse import ArgumentParser + +import colbert.utils.distributed as distributed +from colbert.utils.runs import Run +from colbert.utils.utils import print_message, timestamp, create_directory + + +class Arguments(): + def __init__(self, description): + self.parser = ArgumentParser(description=description) + self.checks = [] + + self.add_argument('--root', dest='root', default='experiments') + self.add_argument('--experiment', dest='experiment', default='dirty') + self.add_argument('--run', dest='run', default=Run.name) + + self.add_argument('--local_rank', dest='rank', default=-1, type=int) + + def add_model_parameters(self): + # Core Arguments + self.add_argument('--similarity', dest='similarity', default='cosine', choices=['cosine', 'l2']) + self.add_argument('--dim', dest='dim', default=128, type=int) + self.add_argument('--query_maxlen', dest='query_maxlen', default=32, type=int) + self.add_argument('--doc_maxlen', dest='doc_maxlen', default=180, type=int) + + # Filtering-related Arguments + self.add_argument('--mask-punctuation', dest='mask_punctuation', default=False, action='store_true') + + def add_model_training_parameters(self): + # NOTE: Providing a checkpoint is one thing, --resume is another, --resume_optimizer is yet another. + self.add_argument('--resume', dest='resume', default=False, action='store_true') + self.add_argument('--resume_optimizer', dest='resume_optimizer', default=False, action='store_true') + self.add_argument('--checkpoint', dest='checkpoint', default=None, required=False) + + self.add_argument('--lr', dest='lr', default=3e-06, type=float) + self.add_argument('--maxsteps', dest='maxsteps', default=400000, type=int) + self.add_argument('--bsize', dest='bsize', default=32, type=int) + self.add_argument('--accum', dest='accumsteps', default=2, type=int) + self.add_argument('--amp', dest='amp', default=False, action='store_true') + + def add_model_inference_parameters(self): + self.add_argument('--checkpoint', dest='checkpoint', required=True) + self.add_argument('--bsize', dest='bsize', default=128, type=int) + self.add_argument('--amp', dest='amp', default=False, action='store_true') + + def add_training_input(self): + self.add_argument('--triples', dest='triples', required=True) + self.add_argument('--queries', dest='queries', default=None) + self.add_argument('--collection', dest='collection', default=None) + + def check_training_input(args): + assert (args.collection is None) == (args.queries is None), \ + "For training, both (or neither) --collection and --queries must be supplied." \ + "If neither is supplied, the --triples file must contain texts (not PIDs)." 
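+        # Registered checks run later, in parse(), once all arguments have been collected.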
+ + self.checks.append(check_training_input) + + def add_ranking_input(self): + self.add_argument('--queries', dest='queries', default=None) + self.add_argument('--collection', dest='collection', default=None) + self.add_argument('--qrels', dest='qrels', default=None) + + def add_reranking_input(self): + self.add_ranking_input() + self.add_argument('--topk', dest='topK', required=True) + self.add_argument('--shortcircuit', dest='shortcircuit', default=False, action='store_true') + + def add_indexing_input(self): + self.add_argument('--collection', dest='collection', required=True) + self.add_argument('--index_root', dest='index_root', required=True) + self.add_argument('--index_name', dest='index_name', required=True) + + def add_index_use_input(self): + self.add_argument('--index_root', dest='index_root', required=True) + self.add_argument('--index_name', dest='index_name', required=True) + self.add_argument('--partitions', dest='partitions', default=None, type=int) + + def add_retrieval_input(self): + self.add_index_use_input() + self.add_argument('--nprobe', dest='nprobe', default=10, type=int) + self.add_argument('--retrieve_only', dest='retrieve_only', default=False, action='store_true') + + def add_argument(self, *args, **kw_args): + return self.parser.add_argument(*args, **kw_args) + + def check_arguments(self, args): + for check in self.checks: + check(args) + + def parse(self): + args = self.parser.parse_args() + self.check_arguments(args) + + args.input_arguments = copy.deepcopy(args) + + args.nranks, args.distributed = distributed.init(args.rank) + + args.nthreads = int(max(os.cpu_count(), faiss.omp_get_max_threads()) * 0.8) + args.nthreads = max(1, args.nthreads // args.nranks) + + if args.nranks > 1: + print_message(f"#> Restricting number of threads for FAISS to {args.nthreads} per process", + condition=(args.rank == 0)) + faiss.omp_set_num_threads(args.nthreads) + + Run.init(args.rank, args.root, args.experiment, args.run) + Run._log_args(args) + Run.info(args.input_arguments.__dict__, '\n') + + return args diff --git a/colbert/utils/runs.py b/colbert/utils/runs.py new file mode 100644 index 0000000000000000000000000000000000000000..79d86b544b7963d5bcf2cd988138c062bb3cfc30 --- /dev/null +++ b/colbert/utils/runs.py @@ -0,0 +1,104 @@ +import os +import sys +import time +import __main__ +import traceback +import mlflow + +import colbert.utils.distributed as distributed + +from contextlib import contextmanager +from colbert.utils.logging import Logger +from colbert.utils.utils import timestamp, create_directory, print_message + + +class _RunManager(): + def __init__(self): + self.experiments_root = None + self.experiment = None + self.path = None + self.script = self._get_script_name() + self.name = self._generate_default_run_name() + self.original_name = self.name + self.exit_status = 'FINISHED' + + self._logger = None + self.start_time = time.time() + + def init(self, rank, root, experiment, name): + assert '/' not in experiment, experiment + assert '/' not in name, name + + self.experiments_root = os.path.abspath(root) + self.experiment = experiment + self.name = name + self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name) + + if rank < 1: + if os.path.exists(self.path): + print('\n\n') + print_message("It seems that ", self.path, " already exists.") + print_message("Do you want to overwrite it? \t yes/no \n") + + # TODO: This should timeout and exit (i.e., fail) given no response for 60 seconds. 
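+                # Anything other than an exact 'yes' trips the assertion below and aborts, leaving the existing run directory untouched.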
+ + response = input() + if response.strip() != 'yes': + assert not os.path.exists(self.path), self.path + else: + create_directory(self.path) + + distributed.barrier(rank) + + self._logger = Logger(rank, self) + self._log_args = self._logger._log_args + self.warn = self._logger.warn + self.info = self._logger.info + self.info_all = self._logger.info_all + self.log_metric = self._logger.log_metric + self.log_new_artifact = self._logger.log_new_artifact + + def _generate_default_run_name(self): + return timestamp() + + def _get_script_name(self): + return os.path.basename(__main__.__file__) if '__file__' in dir(__main__) else 'none' + + @contextmanager + def context(self, consider_failed_if_interrupted=True): + try: + yield + + except KeyboardInterrupt as ex: + print('\n\nInterrupted\n\n') + self._logger._log_exception(ex.__class__, ex, ex.__traceback__) + self._logger._log_all_artifacts() + + if consider_failed_if_interrupted: + self.exit_status = 'KILLED' # mlflow.entities.RunStatus.KILLED + + sys.exit(128 + 2) + + except Exception as ex: + self._logger._log_exception(ex.__class__, ex, ex.__traceback__) + self._logger._log_all_artifacts() + + self.exit_status = 'FAILED' # mlflow.entities.RunStatus.FAILED + + raise ex + + finally: + total_seconds = str(time.time() - self.start_time) + '\n' + original_name = str(self.original_name) + name = str(self.name) + + self.log_new_artifact(os.path.join(self._logger.logs_path, 'elapsed.txt'), total_seconds) + self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.original.txt'), original_name) + self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.txt'), name) + + self._logger._log_all_artifacts() + + mlflow.end_run(status=self.exit_status) + + +Run = _RunManager() diff --git a/colbert/utils/utils.py b/colbert/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c80af838646d99acef304fb6077f2107b80da876 --- /dev/null +++ b/colbert/utils/utils.py @@ -0,0 +1,271 @@ +import os +import tqdm +import torch +import datetime +import itertools + +from multiprocessing import Pool +from collections import OrderedDict, defaultdict + + +def print_message(*s, condition=True): + s = ' '.join([str(x) for x in s]) + msg = "[{}] {}".format(datetime.datetime.now().strftime("%b %d, %H:%M:%S"), s) + + if condition: + print(msg, flush=True) + + return msg + + +def timestamp(): + format_str = "%Y-%m-%d_%H.%M.%S" + result = datetime.datetime.now().strftime(format_str) + return result + + +def file_tqdm(file): + print(f"#> Reading {file.name}") + + with tqdm.tqdm(total=os.path.getsize(file.name) / 1024.0 / 1024.0, unit="MiB") as pbar: + for line in file: + yield line + pbar.update(len(line) / 1024.0 / 1024.0) + + pbar.close() + + +def save_checkpoint(path, epoch_idx, mb_idx, model, optimizer, arguments=None): + print(f"#> Saving a checkpoint to {path} ..") + + if hasattr(model, 'module'): + model = model.module # extract model from a distributed/data-parallel wrapper + + checkpoint = {} + checkpoint['epoch'] = epoch_idx + checkpoint['batch'] = mb_idx + checkpoint['model_state_dict'] = model.state_dict() + checkpoint['optimizer_state_dict'] = optimizer.state_dict() + checkpoint['arguments'] = arguments + + torch.save(checkpoint, path) + + +def load_checkpoint(path, model, optimizer=None, do_print=True): + if do_print: + print_message("#> Loading checkpoint", path, "..") + + if path.startswith("http:") or path.startswith("https:"): + checkpoint = torch.hub.load_state_dict_from_url(path, map_location='cpu') + else: + 
checkpoint = torch.load(path, map_location='cpu') + + state_dict = checkpoint['model_state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k + if k[:7] == 'module.': + name = k[7:] + new_state_dict[name] = v + + checkpoint['model_state_dict'] = new_state_dict + + try: + model.load_state_dict(checkpoint['model_state_dict']) + except: + print_message("[WARNING] Loading checkpoint with strict=False") + model.load_state_dict(checkpoint['model_state_dict'], strict=False) + + if optimizer: + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + if do_print: + print_message("#> checkpoint['epoch'] =", checkpoint['epoch']) + print_message("#> checkpoint['batch'] =", checkpoint['batch']) + + return checkpoint + + +def create_directory(path): + if os.path.exists(path): + print('\n') + print_message("#> Note: Output directory", path, 'already exists\n\n') + else: + print('\n') + print_message("#> Creating directory", path, '\n\n') + os.makedirs(path) + +# def batch(file, bsize): +# while True: +# L = [ujson.loads(file.readline()) for _ in range(bsize)] +# yield L +# return + + +def f7(seq): + """ + Source: https://stackoverflow.com/a/480227/1493011 + """ + + seen = set() + return [x for x in seq if not (x in seen or seen.add(x))] + + +def batch(group, bsize, provide_offset=False): + offset = 0 + while offset < len(group): + L = group[offset: offset + bsize] + yield ((offset, L) if provide_offset else L) + offset += len(L) + return + + +class dotdict(dict): + """ + dot.notation access to dictionary attributes + Credit: derek73 @ https://stackoverflow.com/questions/2352181 + """ + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def flatten(L): + return [x for y in L for x in y] + + +def zipstar(L, lazy=False): + """ + A much faster A, B, C = zip(*[(a, b, c), (a, b, c), ...]) + May return lists or tuples. + """ + + if len(L) == 0: + return L + + width = len(L[0]) + + if width < 100: + return [[elem[idx] for elem in L] for idx in range(width)] + + L = zip(*L) + + return L if lazy else list(L) + + +def zip_first(L1, L2): + length = len(L1) if type(L1) in [tuple, list] else None + + L3 = list(zip(L1, L2)) + + assert length in [None, len(L3)], "zip_first() failure: length differs!" + + return L3 + + +def int_or_float(val): + if '.' in val: + return float(val) + + return int(val) + +def load_ranking(path, types=None, lazy=False): + print_message(f"#> Loading the ranked lists from {path} ..") + + try: + lists = torch.load(path) + lists = zipstar([l.tolist() for l in tqdm.tqdm(lists)], lazy=lazy) + except: + if types is None: + types = itertools.cycle([int_or_float]) + + with open(path) as f: + lists = [[typ(x) for typ, x in zip_first(types, line.strip().split('\t'))] + for line in file_tqdm(f)] + + return lists + + +def save_ranking(ranking, path): + lists = zipstar(ranking) + lists = [torch.tensor(l) for l in lists] + + torch.save(lists, path) + + return lists + + +def groupby_first_item(lst): + groups = defaultdict(list) + + for first, *rest in lst: + rest = rest[0] if len(rest) == 1 else rest + groups[first].append(rest) + + return groups + + +def process_grouped_by_first_item(lst): + """ + Requires items in list to already be grouped by first item. 
+ """ + + groups = defaultdict(list) + + started = False + last_group = None + + for first, *rest in lst: + rest = rest[0] if len(rest) == 1 else rest + + if started and first != last_group: + yield (last_group, groups[last_group]) + assert first not in groups, f"{first} seen earlier --- violates precondition." + + groups[first].append(rest) + + last_group = first + started = True + + return groups + + +def grouper(iterable, n, fillvalue=None): + """ + Collect data into fixed-length chunks or blocks + Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx" + Source: https://docs.python.org/3/library/itertools.html#itertools-recipes + """ + + args = [iter(iterable)] * n + return itertools.zip_longest(*args, fillvalue=fillvalue) + + +# see https://stackoverflow.com/a/45187287 +class NullContextManager(object): + def __init__(self, dummy_resource=None): + self.dummy_resource = dummy_resource + def __enter__(self): + return self.dummy_resource + def __exit__(self, *args): + pass + + +def load_batch_backgrounds(args, qids): + if args.qid2backgrounds is None: + return None + + qbackgrounds = [] + + for qid in qids: + back = args.qid2backgrounds[qid] + + if len(back) and type(back[0]) == int: + x = [args.collection[pid] for pid in back] + else: + x = [args.collectionX.get(pid, '') for pid in back] + + x = ' [SEP] '.join(x) + qbackgrounds.append(x) + + return qbackgrounds diff --git a/docs/images/ColBERT-Framework-MaxSim-W370px.png b/docs/images/ColBERT-Framework-MaxSim-W370px.png new file mode 100644 index 0000000000000000000000000000000000000000..c42e0d67b4707a2c790d548b6a3d28b6fb6e1988 Binary files /dev/null and b/docs/images/ColBERT-Framework-MaxSim-W370px.png differ diff --git a/experiments/MSMARCO-psg/train.py/msmarco.psg.cosine/checkpoints/colbert-400000.dnn b/experiments/MSMARCO-psg/train.py/msmarco.psg.cosine/checkpoints/colbert-400000.dnn new file mode 100644 index 0000000000000000000000000000000000000000..f5804d7b33aa790e5f09f30971ef34c0474e8966 --- /dev/null +++ b/experiments/MSMARCO-psg/train.py/msmarco.psg.cosine/checkpoints/colbert-400000.dnn @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6cc1bdb96728b342bb6595b5d7abd3a8b1f00cd55878a1640d14dc7a2bc04d4d +size 2004344407 diff --git a/fix_colbert_docid.py b/fix_colbert_docid.py new file mode 100644 index 0000000000000000000000000000000000000000..54d385fb2bf33f45d6ad61352bc25daf58c2b358 --- /dev/null +++ b/fix_colbert_docid.py @@ -0,0 +1,46 @@ +import jsonlines +import argparse +import pandas as pd +from tqdm import tqdm + +parser = argparse.ArgumentParser(description=__doc__, + formatter_class=lambda prog: argparse.HelpFormatter(prog, width=100)) +parser.add_argument('--corpus', metavar='FILE', type=str, required=True, help='Corpus file in jsonl') +parser.add_argument('--input_ranking', metavar='FILE', type=str, required=True, help='Ranking file from ColBERT in tsv') +parser.add_argument('--output_ranking', metavar='FILE', type=str, required=True, help='Ranking file with robust doc ids in tsv') +args = parser.parse_args() + + +with jsonlines.open(args.corpus,'r') as reader: + doc_ids = [obj['id'] for obj in reader] + +df = pd.read_csv(args.input_ranking, sep='\t', header=None, names=['query_id', 'doc_id', 'rank']) +df['doc_id'] = df['doc_id'].apply(lambda x: doc_ids[int(x)]) +df['score'] = 1 / df['rank'] + +df = df.sort_values(by='score', ascending=False) +df = df.drop_duplicates(subset=['query_id', 'doc_id']) +df = df.groupby('query_id').head(1000) +df['rank'] = df.groupby('query_id').cumcount() +df = 
df.sort_values(['query_id','rank']) + +with open(args.output_ranking,'w') as writer: + for _, obj in df.iterrows(): + query_id, doc_id, rank, score = obj['query_id'], obj['doc_id'], obj['rank'], obj['score'] + writer.write(f'{query_id}\tQ0\t{doc_id}\t{rank}\t{score}\tColBERT\n') + + + + + + + + # with open(args.input_ranking, 'r') as reader_ranking: + # with open(args.output_ranking,'w') as writer: + + + + # for obj in tqdm(reader_ranking): + # query_id, doc_idx, rank = obj.replace('\n', '').split('\t') + # doc_id = doc_ids[int(doc_idx)] + # writer.write(f'{query_id}\tQ0\t{doc_id}\t{rank}\n') diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..06d3f5fc6f3c4c9d044e7f75955fc27b8118d808 --- /dev/null +++ b/setup.py @@ -0,0 +1,17 @@ +import setuptools + +with open('README.md', 'r') as f: + long_description = f.read() + +setuptools.setup( + name='ColBERT', + version='0.2.0', + author='Omar Khattab', + author_email='okhattab@stanford.edu', + description="Efficient and Effective Passage Search via Contextualized Late Interaction over BERT", + long_description=long_description, + long_description_content_type='text/markdown', + url='https://github.com/stanford-futuredata/ColBERT', + packages=setuptools.find_packages(), + python_requires='>=3.6', +) diff --git a/utility/evaluate/annotate_EM.py b/utility/evaluate/annotate_EM.py new file mode 100644 index 0000000000000000000000000000000000000000..1c21f5b7a3f3cf146006ad4d9f8e5711e9415989 --- /dev/null +++ b/utility/evaluate/annotate_EM.py @@ -0,0 +1,81 @@ +import os +import sys +import git +import tqdm +import ujson +import random + +from argparse import ArgumentParser +from multiprocessing import Pool + +from colbert.utils.utils import print_message, load_ranking, groupby_first_item +from utility.utils.qa_loaders import load_qas_, load_collection_ +from utility.utils.save_metadata import format_metadata, get_metadata +from utility.evaluate.annotate_EM_helpers import * + + +# TODO: Tokenize passages in advance, especially if the ranked list is long! This requires changes to the has_answer input, slightly. + +def main(args): + qas = load_qas_(args.qas) + collection = load_collection_(args.collection, retain_titles=True) + rankings = load_ranking(args.ranking) + parallel_pool = Pool(30) + + print_message('#> Tokenize the answers in the Q&As in parallel...') + qas = list(parallel_pool.map(tokenize_all_answers, qas)) + + qid2answers = {qid: tok_answers for qid, _, tok_answers in qas} + assert len(qas) == len(qid2answers), (len(qas), len(qid2answers)) + + print_message('#> Lookup passages from PIDs...') + expanded_rankings = [(qid, pid, rank, collection[pid], qid2answers[qid]) + for qid, pid, rank, *_ in rankings] + + print_message('#> Assign labels in parallel...') + labeled_rankings = list(parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings))) + + # Dump output. + print_message("#> Dumping output to", args.output, "...") + qid2rankings = groupby_first_item(labeled_rankings) + + num_judged_queries, num_ranked_queries = check_sizes(qid2answers, qid2rankings) + + # Evaluation metrics and depths. + success, counts = compute_and_write_labels(args.output, qid2answers, qid2rankings) + + # Dump metrics. 
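+    # Metrics are normalized by the number of judged queries, so queries missing from the ranking count as failures.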
+ with open(args.output_metrics, 'w') as f: + d = {'num_ranked_queries': num_ranked_queries, 'num_judged_queries': num_judged_queries} + + extra = '__WARNING' if num_judged_queries != num_ranked_queries else '' + d[f'success{extra}'] = {k: v / num_judged_queries for k, v in success.items()} + d[f'counts{extra}'] = {k: v / num_judged_queries for k, v in counts.items()} + d['arguments'] = get_metadata(args) + + f.write(format_metadata(d) + '\n') + + print('\n\n') + print(args.output) + print(args.output_metrics) + print("#> Done\n") + + +if __name__ == "__main__": + random.seed(12345) + + parser = ArgumentParser(description='.') + + # Input / Output Arguments + parser.add_argument('--qas', dest='qas', required=True, type=str) + parser.add_argument('--collection', dest='collection', required=True, type=str) + parser.add_argument('--ranking', dest='ranking', required=True, type=str) + + args = parser.parse_args() + + args.output = f'{args.ranking}.annotated' + args.output_metrics = f'{args.ranking}.annotated.metrics' + + assert not os.path.exists(args.output), args.output + + main(args) diff --git a/utility/evaluate/annotate_EM_helpers.py b/utility/evaluate/annotate_EM_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..11fdd9df6b13ab9ee41ffe86fd154d361f2ce93a --- /dev/null +++ b/utility/evaluate/annotate_EM_helpers.py @@ -0,0 +1,74 @@ +from colbert.utils.utils import print_message +from utility.utils.dpr import DPR_normalize, has_answer + + +def tokenize_all_answers(args): + qid, question, answers = args + return qid, question, [DPR_normalize(ans) for ans in answers] + + +def assign_label_to_passage(args): + idx, (qid, pid, rank, passage, tokenized_answers) = args + + if idx % (1*1000*1000) == 0: + print(idx) + + return qid, pid, rank, has_answer(tokenized_answers, passage) + + +def check_sizes(qid2answers, qid2rankings): + num_judged_queries = len(qid2answers) + num_ranked_queries = len(qid2rankings) + + print_message('num_judged_queries =', num_judged_queries) + print_message('num_ranked_queries =', num_ranked_queries) + + if num_judged_queries != num_ranked_queries: + assert num_ranked_queries <= num_judged_queries + + print('\n\n') + print_message('[WARNING] num_judged_queries != num_ranked_queries') + print('\n\n') + + return num_judged_queries, num_ranked_queries + + +def compute_and_write_labels(output_path, qid2answers, qid2rankings): + cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all'] + success = {cutoff: 0.0 for cutoff in cutoffs} + counts = {cutoff: 0.0 for cutoff in cutoffs} + + with open(output_path, 'w') as f: + for qid in qid2answers: + if qid not in qid2rankings: + continue + + prev_rank = 0 # ranks should start at one (i.e., and not zero) + labels = [] + + for pid, rank, label in qid2rankings[qid]: + assert rank == prev_rank+1, (qid, pid, (prev_rank, rank)) + prev_rank = rank + + labels.append(label) + line = '\t'.join(map(str, [qid, pid, rank, int(label)])) + '\n' + f.write(line) + + for cutoff in cutoffs: + if cutoff != 'all': + success[cutoff] += sum(labels[:cutoff]) > 0 + counts[cutoff] += sum(labels[:cutoff]) + else: + success[cutoff] += sum(labels) > 0 + counts[cutoff] += sum(labels) + + return success, counts + + +# def dump_metrics(f, nqueries, cutoffs, success, counts): +# for cutoff in cutoffs: +# success_log = "#> P@{} = {}".format(cutoff, success[cutoff] / nqueries) +# counts_log = "#> D@{} = {}".format(cutoff, counts[cutoff] / nqueries) +# print('\n'.join([success_log, counts_log]) + '\n') + +# f.write('\n'.join([success_log, 
counts_log]) + '\n\n') diff --git a/utility/evaluate/msmarco_passages.py b/utility/evaluate/msmarco_passages.py new file mode 100644 index 0000000000000000000000000000000000000000..0db3c65e9c1552342e15bcf1de7bb2cac8e3c802 --- /dev/null +++ b/utility/evaluate/msmarco_passages.py @@ -0,0 +1,126 @@ +""" + Evaluate MS MARCO Passages ranking. +""" + +import os +import math +import tqdm +import ujson +import random + +from argparse import ArgumentParser +from collections import defaultdict +from colbert.utils.utils import print_message, file_tqdm + + +def main(args): + qid2positives = defaultdict(list) + qid2ranking = defaultdict(list) + qid2mrr = {} + qid2recall = {depth: {} for depth in [50, 200, 1000]} + + with open(args.qrels) as f: + print_message(f"#> Loading QRELs from {args.qrels} ..") + for line in file_tqdm(f): + qid, _, pid, label = map(int, line.strip().split()) + assert label == 1 + + qid2positives[qid].append(pid) + + with open(args.ranking) as f: + print_message(f"#> Loading ranked lists from {args.ranking} ..") + for line in file_tqdm(f): + qid, pid, rank, *score = line.strip().split('\t') + qid, pid, rank = int(qid), int(pid), int(rank) + + if len(score) > 0: + assert len(score) == 1 + score = float(score[0]) + else: + score = None + + qid2ranking[qid].append((rank, pid, score)) + + assert set.issubset(set(qid2ranking.keys()), set(qid2positives.keys())) + + num_judged_queries = len(qid2positives) + num_ranked_queries = len(qid2ranking) + + if num_judged_queries != num_ranked_queries: + print() + print_message("#> [WARNING] num_judged_queries != num_ranked_queries") + print_message(f"#> {num_judged_queries} != {num_ranked_queries}") + print() + + print_message(f"#> Computing MRR@10 for {num_judged_queries} queries.") + + for qid in tqdm.tqdm(qid2positives): + ranking = qid2ranking[qid] + positives = qid2positives[qid] + + for rank, (_, pid, _) in enumerate(ranking): + rank = rank + 1 # 1-indexed + + if pid in positives: + if rank <= 10: + qid2mrr[qid] = 1.0 / rank + break + + for rank, (_, pid, _) in enumerate(ranking): + rank = rank + 1 # 1-indexed + + if pid in positives: + for depth in qid2recall: + if rank <= depth: + qid2recall[depth][qid] = qid2recall[depth].get(qid, 0) + 1.0 / len(positives) + + assert len(qid2mrr) <= num_ranked_queries, (len(qid2mrr), num_ranked_queries) + + print() + mrr_10_sum = sum(qid2mrr.values()) + print_message(f"#> MRR@10 = {mrr_10_sum / num_judged_queries}") + print_message(f"#> MRR@10 (only for ranked queries) = {mrr_10_sum / num_ranked_queries}") + print() + + for depth in qid2recall: + assert len(qid2recall[depth]) <= num_ranked_queries, (len(qid2recall[depth]), num_ranked_queries) + + print() + metric_sum = sum(qid2recall[depth].values()) + print_message(f"#> Recall@{depth} = {metric_sum / num_judged_queries}") + print_message(f"#> Recall@{depth} (only for ranked queries) = {metric_sum / num_ranked_queries}") + print() + + if args.annotate: + print_message(f"#> Writing annotations to {args.output} ..") + + with open(args.output, 'w') as f: + for qid in tqdm.tqdm(qid2positives): + ranking = qid2ranking[qid] + positives = qid2positives[qid] + + for rank, (_, pid, score) in enumerate(ranking): + rank = rank + 1 # 1-indexed + label = int(pid in positives) + + line = [qid, pid, rank, score, label] + line = [x for x in line if x is not None] + line = '\t'.join(map(str, line)) + '\n' + f.write(line) + + +if __name__ == "__main__": + parser = ArgumentParser(description="msmarco_passages.") + + # Input Arguments. 
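+    # Example (illustrative paths): python utility/evaluate/msmarco_passages.py --qrels qrels.dev.small.tsv --ranking ranking.tsv --annotate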
+ parser.add_argument('--qrels', dest='qrels', required=True, type=str) + parser.add_argument('--ranking', dest='ranking', required=True, type=str) + parser.add_argument('--annotate', dest='annotate', default=False, action='store_true') + + args = parser.parse_args() + + if args.annotate: + args.output = f'{args.ranking}.annotated' + assert not os.path.exists(args.output), args.output + + main(args) diff --git a/utility/preprocess/docs2passages.py b/utility/preprocess/docs2passages.py new file mode 100644 index 0000000000000000000000000000000000000000..0197fc00e80be4a2f40e71364a41900db2b8a888 --- /dev/null +++ b/utility/preprocess/docs2passages.py @@ -0,0 +1,149 @@ +""" + Divide a document collection into N-word/token passage spans (with wrap-around for last passage). +""" + +import os +import math +import ujson +import random + +from multiprocessing import Pool +from argparse import ArgumentParser +from colbert.utils.utils import print_message + +Format1 = 'docid,text' # MS MARCO Passages +Format2 = 'docid,text,title' # DPR Wikipedia +Format3 = 'docid,url,title,text' # MS MARCO Documents + + +def process_page(inp): + """ + Wraps around if we split: make sure last passage isn't too short. + This is meant to be similar to the DPR preprocessing. + """ + + (nwords, overlap, tokenizer), (title_idx, docid, title, url, content) = inp + + if tokenizer is None: + words = content.split() + else: + words = tokenizer.tokenize(content) + + words_ = (words + words) if len(words) > nwords else words + passages = [words_[offset:offset + nwords] for offset in range(0, len(words) - overlap, nwords - overlap)] + + assert all(len(psg) in [len(words), nwords] for psg in passages), (list(map(len, passages)), len(words)) + + if tokenizer is None: + passages = [' '.join(psg) for psg in passages] + else: + passages = [' '.join(psg).replace(' ##', '') for psg in passages] + + if title_idx % 100000 == 0: + print("#> ", title_idx, '\t\t\t', title) + + for p in passages: + print("$$$ ", '\t\t', p) + print() + + print() + print() + print() + + return (docid, title, url, passages) + + +def main(args): + random.seed(12345) + print_message("#> Starting...") + + letter = 'w' if not args.use_wordpiece else 't' + output_path = f'{args.input}.{letter}{args.nwords}_{args.overlap}' + assert not os.path.exists(output_path) + + RawCollection = [] + Collection = [] + + NumIllFormattedLines = 0 + + with open(args.input) as f: + for line_idx, line in enumerate(f): + if line_idx % (100*1000) == 0: + print(line_idx, end=' ') + + title, url = None, None + + try: + line = line.strip().split('\t') + + if args.format == Format1: + docid, doc = line + elif args.format == Format2: + docid, doc, title = line + elif args.format == Format3: + docid, url, title, doc = line + + RawCollection.append((line_idx, docid, title, url, doc)) + except: + NumIllFormattedLines += 1 + + if NumIllFormattedLines % 1000 == 0: + print(f'\n[{line_idx}] NumIllFormattedLines = {NumIllFormattedLines}\n') + + print() + print_message("# of documents is", len(RawCollection), '\n') + + p = Pool(args.nthreads) + + print_message("#> Starting parallel processing...") + + tokenizer = None + if args.use_wordpiece: + from transformers import BertTokenizerFast + tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased') + + process_page_params = [(args.nwords, args.overlap, tokenizer)] * len(RawCollection) + Collection = p.map(process_page, zip(process_page_params, RawCollection)) + + print_message(f"#> Writing to {output_path} ...") + with 
open(output_path, 'w') as f: + line_idx = 1 + + if args.format == Format1: + f.write('\t'.join(['id', 'text']) + '\n') + elif args.format == Format2: + f.write('\t'.join(['id', 'text', 'title']) + '\n') + elif args.format == Format3: + f.write('\t'.join(['id', 'text', 'title', 'docid']) + '\n') + + for docid, title, url, passages in Collection: + for passage in passages: + if args.format == Format1: + f.write('\t'.join([str(line_idx), passage]) + '\n') + elif args.format == Format2: + f.write('\t'.join([str(line_idx), passage, title]) + '\n') + elif args.format == Format3: + f.write('\t'.join([str(line_idx), passage, title, docid]) + '\n') + + line_idx += 1 + + +if __name__ == "__main__": + parser = ArgumentParser(description="docs2passages.") + + # Input Arguments. + parser.add_argument('--input', dest='input', required=True) + parser.add_argument('--format', dest='format', required=True, choices=[Format1, Format2, Format3]) + + # Output Arguments. + parser.add_argument('--use-wordpiece', dest='use_wordpiece', default=False, action='store_true') + parser.add_argument('--nwords', dest='nwords', default=100, type=int) + parser.add_argument('--overlap', dest='overlap', default=0, type=int) + + # Other Arguments. + parser.add_argument('--nthreads', dest='nthreads', default=28, type=int) + + args = parser.parse_args() + assert args.nwords in range(50, 500) + + main(args) diff --git a/utility/preprocess/queries_split.py b/utility/preprocess/queries_split.py new file mode 100644 index 0000000000000000000000000000000000000000..cfbf240d360c91e06e311540fff256951852c54e --- /dev/null +++ b/utility/preprocess/queries_split.py @@ -0,0 +1,81 @@ +""" + Divide a query set into two. +""" + +import os +import math +import ujson +import random + +from argparse import ArgumentParser +from collections import OrderedDict +from colbert.utils.utils import print_message + + +def main(args): + random.seed(12345) + + """ + Load the queries + """ + Queries = OrderedDict() + + print_message(f"#> Loading queries from {args.input}..") + with open(args.input) as f: + for line in f: + qid, query = line.strip().split('\t') + + assert qid not in Queries + Queries[qid] = query + + """ + Apply the splitting + """ + size_a = len(Queries) - args.holdout + size_b = args.holdout + size_a, size_b = max(size_a, size_b), min(size_a, size_b) + + assert size_a > 0 and size_b > 0, (len(Queries), size_a, size_b) + + print_message(f"#> Deterministically splitting the queries into ({size_a}, {size_b})-sized splits.") + + keys = list(Queries.keys()) + sample_b_indices = sorted(list(random.sample(range(len(keys)), size_b))) + sample_a_indices = sorted(list(set.difference(set(list(range(len(keys)))), set(sample_b_indices)))) + + assert len(sample_a_indices) == size_a + assert len(sample_b_indices) == size_b + + sample_a = [keys[idx] for idx in sample_a_indices] + sample_b = [keys[idx] for idx in sample_b_indices] + + """ + Write the output + """ + + output_path_a = f'{args.input}.a' + output_path_b = f'{args.input}.b' + + assert not os.path.exists(output_path_a), output_path_a + assert not os.path.exists(output_path_b), output_path_b + + print_message(f"#> Writing the splits out to {output_path_a} and {output_path_b} ...") + + for output_path, sample in [(output_path_a, sample_a), (output_path_b, sample_b)]: + with open(output_path, 'w') as f: + for qid in sample: + query = Queries[qid] + line = '\t'.join([qid, query]) + '\n' + f.write(line) + + +if __name__ == "__main__": + parser = ArgumentParser(description="queries_split.") + + # Input 
Arguments. + parser.add_argument('--input', dest='input', required=True) + parser.add_argument('--holdout', dest='holdout', required=True, type=int) + + args = parser.parse_args() + + main(args) diff --git a/utility/rankings/dev_subsample.py b/utility/rankings/dev_subsample.py new file mode 100644 index 0000000000000000000000000000000000000000..ceae51eb515a4e50ea5dc4544924ec2823f0789f --- /dev/null +++ b/utility/rankings/dev_subsample.py @@ -0,0 +1,47 @@ +import os +import ujson +import random + +from argparse import ArgumentParser + +from colbert.utils.utils import print_message, create_directory, load_ranking, groupby_first_item +from utility.utils.qa_loaders import load_qas_ + + +def main(args): + print_message("#> Loading all..") + qas = load_qas_(args.qas) + rankings = load_ranking(args.ranking) + qid2rankings = groupby_first_item(rankings) + + print_message("#> Subsampling all..") + qas_sample = random.sample(qas, args.sample) + + with open(args.output, 'w') as f: + for qid, *_ in qas_sample: + for items in qid2rankings[qid]: + items = [qid] + items + line = '\t'.join(map(str, items)) + '\n' + f.write(line) + + print('\n\n') + print(args.output) + print("#> Done.") + + +if __name__ == "__main__": + random.seed(12345) + + parser = ArgumentParser(description='Subsample the dev set.') + parser.add_argument('--qas', dest='qas', required=True, type=str) + parser.add_argument('--ranking', dest='ranking', required=True) + parser.add_argument('--output', dest='output', required=True) + + parser.add_argument('--sample', dest='sample', default=1500, type=int) + + args = parser.parse_args() + + assert not os.path.exists(args.output), args.output + create_directory(os.path.dirname(args.output)) + + main(args) diff --git a/utility/rankings/merge.py b/utility/rankings/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d6f6f14adb01a273d1adad00d840221f169445 --- /dev/null +++ b/utility/rankings/merge.py @@ -0,0 +1,57 @@ +""" + Merge two or more ranking files, by score. +""" + +import os +import tqdm + +from argparse import ArgumentParser +from collections import defaultdict +from colbert.utils.utils import print_message, file_tqdm + + +def main(args): + Rankings = defaultdict(list) + + for path in args.input: + print_message(f"#> Loading the rankings in {path} ..") + + with open(path) as f: + for line in file_tqdm(f): + qid, pid, rank, score = line.strip().split('\t') + qid, pid, rank = map(int, [qid, pid, rank]) + score = float(score) + + Rankings[qid].append((score, rank, pid)) + + with open(args.output, 'w') as f: + print_message(f"#> Writing the output rankings to {args.output} ..") + + for qid in tqdm.tqdm(Rankings): + ranking = sorted(Rankings[qid], reverse=True) + + for rank, (score, original_rank, pid) in enumerate(ranking): + rank = rank + 1 # 1-indexed + + if (args.depth > 0) and (rank > args.depth): + break + + line = [qid, pid, rank, score] + line = '\t'.join(map(str, line)) + '\n' + f.write(line) + + +if __name__ == "__main__": + parser = ArgumentParser(description="merge_rankings.") + + # Input Arguments. 
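+    # Example (illustrative paths): python utility/rankings/merge.py --input ranking.a.tsv ranking.b.tsv --output ranking.merged.tsv --depth 1000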
+ parser.add_argument('--input', dest='input', required=True, nargs='+') + parser.add_argument('--output', dest='output', required=True, type=str) + + parser.add_argument('--depth', dest='depth', required=True, type=int) + + args = parser.parse_args() + + assert not os.path.exists(args.output), args.output + + main(args) diff --git a/utility/rankings/split_by_offset.py b/utility/rankings/split_by_offset.py new file mode 100644 index 0000000000000000000000000000000000000000..d48f597cbbd0b4aecb5d1c99059045468eb97585 --- /dev/null +++ b/utility/rankings/split_by_offset.py @@ -0,0 +1,44 @@ +""" +Split the ranked lists after retrieval with a merged query set. +""" + +import os +import random + +from argparse import ArgumentParser + + +def main(args): + output_paths = ['{}.{}'.format(args.ranking, split) for split in args.names] + assert all(not os.path.exists(path) for path in output_paths), output_paths + + output_files = [open(path, 'w') for path in output_paths] + + with open(args.ranking) as f: + for line in f: + qid, pid, rank, *other = line.strip().split('\t') + qid = int(qid) + split_output_path = output_files[qid // args.gap - 1] + qid = qid % args.gap + + split_output_path.write('\t'.join([str(x) for x in [qid, pid, rank, *other]]) + '\n') + + print(f.name) + + _ = [f.close() for f in output_files] + + print("#> Done!") + + +if __name__ == "__main__": + random.seed(12345) + + parser = ArgumentParser(description='Split the ranked lists after retrieval with a merged query set.') + parser.add_argument('--ranking', dest='ranking', required=True) + + parser.add_argument('--names', dest='names', required=False, default=['train', 'dev', 'test'], type=str, nargs='+') # order matters! + parser.add_argument('--gap', dest='gap', required=False, default=1_000_000_000, type=int) # larger than any individual query set + + args = parser.parse_args() + + main(args) diff --git a/utility/rankings/split_by_queries.py b/utility/rankings/split_by_queries.py new file mode 100644 index 0000000000000000000000000000000000000000..690dd92473afc357fa97a32720ec1458231a8872 --- /dev/null +++ b/utility/rankings/split_by_queries.py @@ -0,0 +1,67 @@ +import os +import sys +import tqdm +import ujson +import random + +from argparse import ArgumentParser +from collections import OrderedDict +from colbert.utils.utils import print_message, file_tqdm + + +def main(args): + qid_to_file_idx = {} + + for qrels_idx, qrels in enumerate(args.all_queries): + with open(qrels) as f: + for line in f: + qid, *_ = line.strip().split('\t') + qid = int(qid) + + assert qid_to_file_idx.get(qid, qrels_idx) == qrels_idx, (qid, qrels_idx) + qid_to_file_idx[qid] = qrels_idx + + all_outputs_paths = [f'{args.ranking}.{idx}' for idx in range(len(args.all_queries))] + + assert all(not os.path.exists(path) for path in all_outputs_paths) + + all_outputs = [open(path, 'w') for path in all_outputs_paths] + + with open(args.ranking) as f: + print_message(f"#> Loading ranked lists from {f.name} ..") + + last_file_idx = -1 + + for line in file_tqdm(f): + qid, *_ = line.strip().split('\t') + + file_idx = qid_to_file_idx[int(qid)] + + if file_idx != last_file_idx: + print_message(f"#> Switched to file #{file_idx} at {all_outputs[file_idx].name}") + + last_file_idx = file_idx + + all_outputs[file_idx].write(line) + + print() + + for f in all_outputs: + print(f.name) + f.close() + + print("#> Done!") + + +if __name__ == "__main__": + random.seed(12345) + + parser = ArgumentParser(description='.') + + # Input Arguments + parser.add_argument('--ranking', dest='ranking', required=True, type=str) + 
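# Each file in --all-queries claims a set of qids; ranking lines are routed to the output split of the file containing their qid. +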
parser.add_argument('--all-queries', dest='all_queries', required=True, type=str, nargs='+') + + args = parser.parse_args() + + main(args) diff --git a/utility/rankings/tune.py b/utility/rankings/tune.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4cb16ea2d21d4767cbd2208dcc20ea6d2617c8 --- /dev/null +++ b/utility/rankings/tune.py @@ -0,0 +1,66 @@ +import os +import ujson +import random + +from argparse import ArgumentParser +from colbert.utils.utils import print_message, create_directory +from utility.utils.save_metadata import save_metadata + + +def main(args): + AllMetrics = {} + Scores = {} + + for path in args.paths: + with open(path) as f: + metric = ujson.load(f) + AllMetrics[path] = metric + + for k in args.metric: + metric = metric[k] + + assert type(metric) is float + Scores[path] = metric + + MaxKey = max(Scores, key=Scores.get) + + MaxCKPT = int(MaxKey.split('/')[-2].split('.')[-1]) + MaxARGS = os.path.join(os.path.dirname(MaxKey), 'logs', 'args.json') + + with open(MaxARGS) as f: + logs = ujson.load(f) + MaxCHECKPOINT = logs['checkpoint'] + + assert MaxCHECKPOINT.endswith(f'colbert-{MaxCKPT}.dnn'), (MaxCHECKPOINT, MaxCKPT) + + with open(args.output, 'w') as f: + f.write(MaxCHECKPOINT) + + args.Scores = Scores + args.AllMetrics = AllMetrics + + save_metadata(f'{args.output}.meta', args) + + print('\n\n', args, '\n\n') + print(args.output) + print_message("#> Done.") + + +if __name__ == "__main__": + random.seed(12345) + + parser = ArgumentParser(description='.') + + # Input / Output Arguments + parser.add_argument('--metric', dest='metric', required=True, type=str) # e.g., success.20 + parser.add_argument('--paths', dest='paths', required=True, type=str, nargs='+') + parser.add_argument('--output', dest='output', required=True, type=str) + + args = parser.parse_args() + + args.metric = args.metric.split('.') + + assert not os.path.exists(args.output), args.output + create_directory(os.path.dirname(args.output)) + + main(args) diff --git a/utility/supervision/self_training.py b/utility/supervision/self_training.py new file mode 100644 index 0000000000000000000000000000000000000000..f30dea4d1aa556fa84d62bb176dcb16d5c2eb331 --- /dev/null +++ b/utility/supervision/self_training.py @@ -0,0 +1,123 @@ +import os +import sys +import git +import tqdm +import ujson +import random + +from argparse import ArgumentParser +from colbert.utils.utils import print_message, load_ranking, groupby_first_item + + +MAX_NUM_TRIPLES = 40_000_000 + + +def sample_negatives(negatives, num_sampled, biased=False): + num_sampled = min(len(negatives), num_sampled) + + if biased: + assert num_sampled % 2 == 0 + num_sampled_top100 = num_sampled // 2 + num_sampled_rest = num_sampled - num_sampled_top100 + + return random.sample(negatives[:100], num_sampled_top100) + random.sample(negatives[100:], num_sampled_rest) + + return random.sample(negatives, num_sampled) + + +def sample_for_query(qid, ranking, npositives, depth_positive, depth_negative, cutoff_negative): + """ + Requires that the ranks are sorted per qid. 
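+    Positives are the pids ranked within the top depth_positive; negatives are the pids ranked in (cutoff_negative, depth_negative].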
+ """ + assert npositives <= depth_positive < cutoff_negative < depth_negative + + positives, negatives, triples = [], [], [] + + for pid, rank, *_ in ranking: + assert rank >= 1, f"ranks should start at 1 \t\t got rank = {rank}" + + if rank > depth_negative: + break + + if rank <= depth_positive: + positives.append(pid) + elif rank > cutoff_negative: + negatives.append(pid) + + num_sampled = 100 + + for neg in sample_negatives(negatives, num_sampled): + positives_ = random.sample(positives, npositives) + positives_ = positives_[0] if npositives == 1 else positives_ + triples.append((qid, positives_, neg)) + + return triples + + +def main(args): + rankings = load_ranking(args.ranking, types=[int, int, int, float, int]) + + print_message("#> Group by QID") + qid2rankings = groupby_first_item(tqdm.tqdm(rankings)) + + Triples = [] + NonEmptyQIDs = 0 + + for processing_idx, qid in enumerate(qid2rankings): + l = sample_for_query(qid, qid2rankings[qid], args.positives, args.depth_positive, args.depth_negative, args.cutoff_negative) + NonEmptyQIDs += (len(l) > 0) + Triples.extend(l) + + if processing_idx % (10_000) == 0: + print_message(f"#> Done with {processing_idx+1} questions!\t\t " + f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unqiue QIDs.") + + print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..") + print_message(f"#> len(Triples) = {len(Triples)}") + + if len(Triples) > MAX_NUM_TRIPLES: + Triples = random.sample(Triples, MAX_NUM_TRIPLES) + + ### Prepare the triples ### + print_message("#> Shuffling the triples...") + random.shuffle(Triples) + + print_message("#> Writing {}M examples to file.".format(len(Triples) / 1000.0 / 1000.0)) + + with open(args.output, 'w') as f: + for example in Triples: + ujson.dump(example, f) + f.write('\n') + + with open(f'{args.output}.meta', 'w') as f: + args.cmd = ' '.join(sys.argv) + args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha + ujson.dump(args.__dict__, f, indent=4) + f.write('\n') + + print('\n\n', args, '\n\n') + print(args.output) + print_message("#> Done.") + + +if __name__ == "__main__": + random.seed(12345) + + parser = ArgumentParser(description='Create training triples from ranked list.') + + # Input / Output Arguments + parser.add_argument('--ranking', dest='ranking', required=True, type=str) + parser.add_argument('--output', dest='output', required=True, type=str) + + # Weak Supervision Arguments. 
+ parser.add_argument('--positives', dest='positives', required=True, type=int) + parser.add_argument('--depth+', dest='depth_positive', required=True, type=int) + + parser.add_argument('--depth-', dest='depth_negative', required=True, type=int) + parser.add_argument('--cutoff-', dest='cutoff_negative', required=True, type=int) + + args = parser.parse_args() + + assert not os.path.exists(args.output), args.output + + main(args) diff --git a/utility/supervision/triples.py b/utility/supervision/triples.py new file mode 100644 index 0000000000000000000000000000000000000000..0bbfc9a3e1e4fbe05951dd2d26868174f9c5ef44 --- /dev/null +++ b/utility/supervision/triples.py @@ -0,0 +1,150 @@ +""" + Example: --positives 5,50 1,1000 ~~> best-5 (in top-50) + best-1 (in top-1000) +""" + +import os +import sys +import git +import tqdm +import ujson +import random + +from argparse import ArgumentParser +from colbert.utils.utils import print_message, load_ranking, groupby_first_item, create_directory +from utility.utils.save_metadata import save_metadata + + +MAX_NUM_TRIPLES = 40_000_000 + + +def sample_negatives(negatives, num_sampled, biased=None): + assert biased in [None, 100, 200], "NOTE: We bias 50% from the top-200 negatives, if there are twice or more." + + num_sampled = min(len(negatives), num_sampled) + + if biased and num_sampled < len(negatives): + assert num_sampled % 2 == 0, num_sampled + + num_sampled_top100 = num_sampled // 2 + num_sampled_rest = num_sampled - num_sampled_top100 + + oversampled, undersampled = negatives[:biased], negatives[biased:] + + if len(oversampled) < len(undersampled): + return random.sample(oversampled, num_sampled_top100) + random.sample(undersampled, num_sampled_rest) + + return random.sample(negatives, num_sampled) + + +def sample_for_query(qid, ranking, args_positives, depth, permissive, biased): + """ + Requires that the ranks are sorted per qid. 
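+    args_positives is a list of (maxBest, maxDepth) pairs, e.g. [(5, 50), (1, 1000)]: a pid at rank r is taken as a positive if, for some pair, r <= maxDepth and fewer than maxBest positives have been taken so far.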
+ """ + + positives, negatives, triples = [], [], [] + + for pid, rank, *_, label in ranking: + assert rank >= 1, f"ranks should start at 1 \t\t got rank = {rank}" + assert label in [0, 1] + + if rank > depth: + break + + if label: + take_this_positive = any(rank <= maxDepth and len(positives) < maxBest for maxBest, maxDepth in args_positives) + + if take_this_positive: + positives.append((pid, 0)) + elif permissive: + positives.append((pid, rank)) # utilize with a few negatives, starting at (next) rank + + else: + negatives.append(pid) + + for pos, neg_start in positives: + num_sampled = 100 if neg_start == 0 else 5 + negatives_ = negatives[neg_start:] + + biased_ = biased if neg_start == 0 else None + for neg in sample_negatives(negatives_, num_sampled, biased=biased_): + triples.append((qid, pos, neg)) + + return triples + + +def main(args): + try: + rankings = load_ranking(args.ranking, types=[int, int, int, float, int]) + except: + rankings = load_ranking(args.ranking, types=[int, int, int, int]) + + print_message("#> Group by QID") + qid2rankings = groupby_first_item(tqdm.tqdm(rankings)) + + Triples = [] + NonEmptyQIDs = 0 + + for processing_idx, qid in enumerate(qid2rankings): + l = sample_for_query(qid, qid2rankings[qid], args.positives, args.depth, args.permissive, args.biased) + NonEmptyQIDs += (len(l) > 0) + Triples.extend(l) + + if processing_idx % (10_000) == 0: + print_message(f"#> Done with {processing_idx+1} questions!\t\t " + f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unqiue QIDs.") + + print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..") + print_message(f"#> len(Triples) = {len(Triples)}") + + if len(Triples) > MAX_NUM_TRIPLES: + Triples = random.sample(Triples, MAX_NUM_TRIPLES) + + ### Prepare the triples ### + print_message("#> Shuffling the triples...") + random.shuffle(Triples) + + print_message("#> Writing {}M examples to file.".format(len(Triples) / 1000.0 / 1000.0)) + + with open(args.output, 'w') as f: + for example in Triples: + ujson.dump(example, f) + f.write('\n') + + save_metadata(f'{args.output}.meta', args) + + print('\n\n', args, '\n\n') + print(args.output) + print_message("#> Done.") + + +if __name__ == "__main__": + parser = ArgumentParser(description='Create training triples from ranked list.') + + # Input / Output Arguments + parser.add_argument('--ranking', dest='ranking', required=True, type=str) + parser.add_argument('--output', dest='output', required=True, type=str) + + # Weak Supervision Arguments. 
+ parser.add_argument('--positives', dest='positives', required=True, nargs='+') + parser.add_argument('--depth', dest='depth', required=True, type=int) # for negatives + + parser.add_argument('--permissive', dest='permissive', default=False, action='store_true') + # parser.add_argument('--biased', dest='biased', default=False, action='store_true') + parser.add_argument('--biased', dest='biased', default=None, type=int) + parser.add_argument('--seed', dest='seed', required=False, default=12345, type=int) + + args = parser.parse_args() + random.seed(args.seed) + + assert not os.path.exists(args.output), args.output + + args.positives = [list(map(int, configuration.split(','))) for configuration in args.positives] + + assert all(len(x) == 2 for x in args.positives) + assert all(maxBest <= maxDepth for maxBest, maxDepth in args.positives), args.positives + + create_directory(os.path.dirname(args.output)) + + assert args.biased in [None, 100, 200] + + main(args) diff --git a/utility/utils/dpr.py b/utility/utils/dpr.py new file mode 100644 index 0000000000000000000000000000000000000000..205c093d67f048455332007ec2e2fa43392dcecb --- /dev/null +++ b/utility/utils/dpr.py @@ -0,0 +1,237 @@ +""" + Source: DPR Implementation from Facebook Research + https://github.com/facebookresearch/DPR/tree/master/dpr +""" + +import copy +import logging +import string +import spacy +import regex +import unicodedata + +logger = logging.getLogger(__name__) + + +class Tokens(object): + """A class to represent a list of tokenized text.""" + TEXT = 0 + TEXT_WS = 1 + SPAN = 2 + POS = 3 + LEMMA = 4 + NER = 5 + + def __init__(self, data, annotators, opts=None): + self.data = data + self.annotators = annotators + self.opts = opts or {} + + def __len__(self): + """The number of tokens.""" + return len(self.data) + + def slice(self, i=None, j=None): + """Return a view of the list of tokens from [i, j).""" + new_tokens = copy.copy(self) + new_tokens.data = self.data[i: j] + return new_tokens + + def untokenize(self): + """Returns the original text (with whitespace reinserted).""" + return ''.join([t[self.TEXT_WS] for t in self.data]).strip() + + def words(self, uncased=False): + """Returns a list of the text of each token + + Args: + uncased: lower cases text + """ + if uncased: + return [t[self.TEXT].lower() for t in self.data] + else: + return [t[self.TEXT] for t in self.data] + + def offsets(self): + """Returns a list of [start, end) character offsets of each token.""" + return [t[self.SPAN] for t in self.data] + + def pos(self): + """Returns a list of part-of-speech tags of each token. + Returns None if this annotation was not included. + """ + if 'pos' not in self.annotators: + return None + return [t[self.POS] for t in self.data] + + def lemmas(self): + """Returns a list of the lemmatized text of each token. + Returns None if this annotation was not included. + """ + if 'lemma' not in self.annotators: + return None + return [t[self.LEMMA] for t in self.data] + + def entities(self): + """Returns a list of named-entity-recognition tags of each token. + Returns None if this annotation was not included. + """ + if 'ner' not in self.annotators: + return None + return [t[self.NER] for t in self.data] + + def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): + """Returns a list of all ngrams from length 1 to n. 
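+        Each ngram is a half-open (start, end) index pair over the token list, unless as_strings is True, in which case joined strings are returned.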
+
+        Args:
+            n: upper limit of ngram length
+            uncased: lower-cases text
+            filter_fn: user function that takes in an ngram list and returns
+              True to skip (filter out) that ngram
+            as_strings: return each ngram as a joined string vs. a token list
+        """
+
+        def _skip(gram):
+            if not filter_fn:
+                return False
+            return filter_fn(gram)
+
+        words = self.words(uncased)
+        ngrams = [(s, e + 1)
+                  for s in range(len(words))
+                  for e in range(s, min(s + n, len(words)))
+                  if not _skip(words[s:e + 1])]
+
+        # Concatenate into strings
+        if as_strings:
+            ngrams = [' '.join(words[s:e]) for (s, e) in ngrams]
+
+        return ngrams
+
+    def entity_groups(self):
+        """Group consecutive entity tokens with the same NER tag."""
+        entities = self.entities()
+        if not entities:
+            return None
+        non_ent = self.opts.get('non_ent', 'O')
+        groups = []
+        idx = 0
+        while idx < len(entities):
+            ner_tag = entities[idx]
+            # Check for entity tag
+            if ner_tag != non_ent:
+                # Chomp the sequence
+                start = idx
+                while idx < len(entities) and entities[idx] == ner_tag:
+                    idx += 1
+                groups.append((self.slice(start, idx).untokenize(), ner_tag))
+            else:
+                idx += 1
+        return groups
+
+
+class Tokenizer(object):
+    """Base tokenizer class.
+    Tokenizers implement tokenize, which should return a Tokens class.
+    """
+
+    def tokenize(self, text):
+        raise NotImplementedError
+
+    def shutdown(self):
+        pass
+
+    def __del__(self):
+        self.shutdown()
+
+
+class SimpleTokenizer(Tokenizer):
+    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
+    NON_WS = r'[^\p{Z}\p{C}]'
+
+    def __init__(self, **kwargs):
+        """
+        Args:
+            annotators: None or empty set (only tokenizes).
+        """
+        self._regexp = regex.compile(
+            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
+            flags=regex.IGNORECASE | regex.UNICODE | regex.MULTILINE
+        )
+        if len(kwargs.get('annotators', {})) > 0:
+            logger.warning('%s only tokenizes! Skipping annotators: %s' %
+                           (type(self).__name__, kwargs.get('annotators')))
+        self.annotators = set()
+
+    def tokenize(self, text):
+        data = []
+        matches = list(self._regexp.finditer(text))
+        for i in range(len(matches)):
+            # Get text
+            token = matches[i].group()
+
+            # Get whitespace
+            span = matches[i].span()
+            start_ws = span[0]
+            if i + 1 < len(matches):
+                end_ws = matches[i + 1].span()[0]
+            else:
+                end_ws = span[1]
+
+            # Format data
+            data.append((
+                token,
+                text[start_ws: end_ws],
+                span,
+            ))
+        return Tokens(data, self.annotators)
+
+
+def has_answer(tokenized_answers, text):
+    # Check whether any tokenized answer occurs as a contiguous subsequence of
+    # the normalized text's words.
+    text = DPR_normalize(text)
+
+    for single_answer in tokenized_answers:
+        for i in range(0, len(text) - len(single_answer) + 1):
+            if single_answer == text[i: i + len(single_answer)]:
+                return True
+
+    return False
+
+
+def locate_answers(tokenized_answers, text):
+    """
+    Returns each occurrence of an answer as (offset, endpos) in terms of *characters*.
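+    For example, locating "Barack Obama" in "Barack Obama was the 44th president"
+    yields [(0, 12)]: character offsets into the (NFD-normalized) input text.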
+    """
+    tokenized_text = DPR_tokenize(text)
+    occurrences = []
+
+    text_words, text_word_positions = tokenized_text.words(uncased=True), tokenized_text.offsets()
+    answers_words = [ans.words(uncased=True) for ans in tokenized_answers]
+
+    for single_answer in answers_words:
+        for i in range(0, len(text_words) - len(single_answer) + 1):
+            if single_answer == text_words[i: i + len(single_answer)]:
+                (offset, _), (_, endpos) = text_word_positions[i], text_word_positions[i+len(single_answer)-1]
+                occurrences.append((offset, endpos))
+
+    return occurrences
+
+
+# Module-level tokenizer shared by DPR_tokenize and DPR_normalize.
+STokenizer = SimpleTokenizer()
+
+
+def DPR_tokenize(text):
+    return STokenizer.tokenize(unicodedata.normalize('NFD', text))
+
+
+def DPR_normalize(text):
+    return DPR_tokenize(text).words(uncased=True)
+
+
+# Source: https://github.com/shmsw25/qa-hard-em/blob/master/prepro_util.py
+def strip_accents(text):
+    """Strips accents from a piece of text."""
+    text = unicodedata.normalize("NFD", text)
+    output = []
+    for char in text:
+        cat = unicodedata.category(char)
+        if cat == "Mn":
+            continue
+        output.append(char)
+    return "".join(output)
diff --git a/utility/utils/qa_loaders.py b/utility/utils/qa_loaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..10c7f459bc6d36ba07afd467946487122d61aa62
--- /dev/null
+++ b/utility/utils/qa_loaders.py
@@ -0,0 +1,33 @@
+import os
+import ujson
+
+from colbert.utils.utils import print_message, file_tqdm
+
+
+def load_collection_(path, retain_titles):
+    with open(path) as f:
+        collection = []
+
+        for line in file_tqdm(f):
+            # Each line: doc_id \t passage \t title (the doc_id is ignored).
+            _, passage, title = line.strip().split('\t')
+
+            if retain_titles:
+                passage = title + ' | ' + passage
+
+            collection.append(passage)
+
+    return collection
+
+
+def load_qas_(path):
+    print_message("#> Loading the reference QAs from", path)
+
+    triples = []
+
+    with open(path) as f:
+        for line in f:
+            qa = ujson.loads(line)
+            triples.append((qa['qid'], qa['question'], qa['answers']))
+
+    return triples
diff --git a/utility/utils/save_metadata.py b/utility/utils/save_metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..71dd2c4b3e64fad4ae4ad35bfb8b5b31bb075979
--- /dev/null
+++ b/utility/utils/save_metadata.py
@@ -0,0 +1,46 @@
+import os
+import sys
+import git
+import time
+import copy
+import ujson
+import socket
+
+
+def get_metadata(args):
+    args = copy.deepcopy(args)
+    repo = git.Repo(search_parent_directories=True)
+
+    args.hostname = socket.gethostname()
+    args.git_branch = repo.active_branch.name
+    args.git_hash = repo.head.object.hexsha
+    args.git_commit_datetime = str(repo.head.object.committed_datetime)
+    args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)')  # NOTE: %l is a glibc extension
+    args.cmd = ' '.join(sys.argv)
+
+    try:
+        args.input_arguments = copy.deepcopy(args.input_arguments.__dict__)
+    except AttributeError:
+        args.input_arguments = None
+
+    return dict(args.__dict__)
+
+
+def format_metadata(metadata):
+    assert isinstance(metadata, dict)
+
+    return ujson.dumps(metadata, indent=4)
+
+
+def save_metadata(path, args):
+    assert not os.path.exists(path), path
+
+    with open(path, 'w') as output_metadata:
+        data = get_metadata(args)
+        output_metadata.write(format_metadata(data) + '\n')
+
+    return data
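+
+
+# Example usage (hypothetical path): persist the arguments of a finished run.
+#   save_metadata('experiments/run.meta', args)  # writes pretty-printed JSON and returns it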