import os
import time
import typing
from functools import wraps

import gradio as gr
import lancedb
from FlagEmbedding import FlagReranker
from sentence_transformers import SentenceTransformer


def timeit(func):
    """Decorator that logs how long the wrapped function takes to run."""
    @wraps(func)
    def timeit_wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        total_time = time.perf_counter() - start_time
        print(f"Function {func.__name__}{args} {kwargs} took {total_time:.4f} seconds")
        return result
    return timeit_wrapper


# Connect to the local LanceDB store and read the table/column configuration
# from environment variables.
db = lancedb.connect(".lancedb")
TABLE = db.open_table(os.getenv("TABLE_NAME"))
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")

TOP_K = int(os.getenv("TOP_K", 5))                # candidates fetched by vector search
TOP_K_RERANK = int(os.getenv("TOP_K_RERANK", 2))  # documents kept after reranking
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-large")

# Bi-encoder for embedding queries, cross-encoder for reranking candidates.
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
# Setting use_fp16 to True speeds up computation with a slight performance degradation.
reranker = FlagReranker(RERANK_MODEL, use_fp16=True)


def rerank(query: str, documents: typing.List[str], k: int):
    """Score every (query, document) pair with the cross-encoder and keep the top-k documents."""
    data_for_reranker = [(query, document) for document in documents]
    scores = reranker.compute_score(data_for_reranker, batch_size=BATCH_SIZE)
    indices_scores = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
    best_indices = [i for i, _ in indices_scores[:k]]
    return [documents[i] for i in best_indices]


@timeit
def retrieve(query, k):
    """Embed the query, fetch the k nearest chunks from LanceDB, then rerank them."""
    query_vec = retriever.encode(query)
    try:
        documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
        documents = [doc[TEXT_COLUMN] for doc in documents]
        return rerank(query, documents, TOP_K_RERANK)
    except Exception as e:
        raise gr.Error(str(e))


if __name__ == "__main__":
    # Smoke test / model warm-up when the module is run directly.
    retrieve("What is RAG?", TOP_K)
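

# --- Optional: a minimal Gradio front end for the retriever ---
# A sketch, not part of the original pipeline: it assumes a plain
# textbox-in / JSON-out demo is sufficient and simply reuses `retrieve` as
# defined above. The LAUNCH_DEMO environment variable is hypothetical and
# only gates this example; leave it unset to keep the original behavior.
if __name__ == "__main__" and os.getenv("LAUNCH_DEMO"):
    demo = gr.Interface(
        fn=lambda query: retrieve(query, TOP_K),  # embed, vector-search, rerank
        inputs=gr.Textbox(label="Query"),
        outputs=gr.JSON(label="Top reranked passages"),
        title="LanceDB retrieval + BGE reranking",
    )
    demo.launch()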