import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint


def load_model(args, do_print=True):
    """Build a ColBERT model, restore its checkpoint, and put it in eval mode.

    The model is instantiated via ``ColBERT.from_pretrained`` using the
    'bert-base-multilingual-uncased' weights, moved to ``DEVICE``, and then
    handed to ``load_checkpoint`` together with ``args.checkpoint``.

    Args:
        args: Object exposing the attributes read here: ``query_maxlen``,
            ``doc_maxlen``, ``dim``, ``similarity``, ``mask_punctuation``
            and ``checkpoint``.
        do_print: Forwarded as the ``condition`` of the progress message and
            as ``do_print`` to ``load_checkpoint``.

    Returns:
        A ``(model, checkpoint)`` tuple, where ``checkpoint`` is whatever
        ``load_checkpoint`` returned.
    """
    model_kwargs = {
        'query_maxlen': args.query_maxlen,
        'doc_maxlen': args.doc_maxlen,
        'dim': args.dim,
        'similarity_metric': args.similarity,
        'mask_punctuation': args.mask_punctuation,
    }
    model = ColBERT.from_pretrained('bert-base-multilingual-uncased', **model_kwargs).to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)
    restored = load_checkpoint(args.checkpoint, model, do_print=do_print)

    # Switch to evaluation mode only after the checkpoint has been loaded,
    # matching the original statement order.
    model.eval()

    return model, restored