|
from .vss import VSS |
|
from .llm_reranker import LLMReranker |
|
from .multi_vss import MultiVSS |
|
from .bm25 import BM25 |
|
from .colbertv2 import Colbertv2 |
|
|
|
def get_model(args, skb, **kwargs):
    """Instantiate the retrieval model selected by ``args.model``.

    Supported names are ``'BM25'``, ``'Colbertv2'``, ``'VSS'``,
    ``'MultiVSS'`` and ``'LLMReranker'``. Every model receives ``skb`` as its
    first positional argument, plus constructor options read from ``args``.

    Args:
        args: Namespace-like configuration object (e.g. parsed CLI arguments).
            ``args.model`` is always read; other attributes are read per model:

            * ``Colbertv2``: ``dataset``, ``output_dir``, ``download_dir``,
              ``split``.
            * ``VSS``: ``emb_model``, ``query_emb_dir``, ``node_emb_dir``,
              ``device``.
            * ``MultiVSS``: the ``VSS`` attributes plus ``chunk_emb_dir``,
              ``aggregate``, ``chunk_size``, ``multi_vss_topk``.
            * ``LLMReranker``: the ``VSS`` attributes plus ``llm_model``,
              ``max_retry``, ``llm_topk``.
        skb: Object passed unchanged to the model constructor (presumably the
            semi-structured knowledge base -- confirm against callers).
        **kwargs: Extra keyword arguments forwarded to ``Colbertv2`` only;
            they are ignored for every other model.

    Returns:
        The constructed model instance.

    Raises:
        ImportError: If constructing ``Colbertv2`` raises ``ImportError``
            (including ``ModuleNotFoundError``), e.g. because ``colbert-ai``
            is not installed. The original error is chained as the cause.
        NotImplementedError: If ``args.model`` is not a supported name.
    """
    model_name = args.model

    if model_name == 'BM25':
        return BM25(skb)

    if model_name == 'Colbertv2':
        try:
            return Colbertv2(
                skb,
                dataset_name=args.dataset,
                save_dir=args.output_dir,
                download_dir=args.download_dir,
                # The flag is derived from the evaluation split name.
                human_generated_eval=args.split == 'human_generated_eval',
                **kwargs,
            )
        except ImportError as err:
            # Chain explicitly so the underlying missing module/symbol stays
            # visible as __cause__ instead of as implicit handling context.
            raise ImportError(
                "Please install the colbert package using `pip install colbert-ai`."
            ) from err

    if model_name == 'VSS':
        return VSS(
            skb,
            emb_model=args.emb_model,
            query_emb_dir=args.query_emb_dir,
            candidates_emb_dir=args.node_emb_dir,
            device=args.device,
        )

    if model_name == 'MultiVSS':
        return MultiVSS(
            skb,
            emb_model=args.emb_model,
            query_emb_dir=args.query_emb_dir,
            candidates_emb_dir=args.node_emb_dir,
            chunk_emb_dir=args.chunk_emb_dir,
            aggregate=args.aggregate,
            chunk_size=args.chunk_size,
            max_k=args.multi_vss_topk,
            device=args.device,
        )

    if model_name == 'LLMReranker':
        return LLMReranker(
            skb,
            emb_model=args.emb_model,
            llm_model=args.llm_model,
            query_emb_dir=args.query_emb_dir,
            candidates_emb_dir=args.node_emb_dir,
            max_cnt=args.max_retry,
            max_k=args.llm_topk,
            device=args.device,
        )

    raise NotImplementedError(f'{model_name} not implemented')