|
|
|
from typing import List,Optional |
|
from langchain.vectorstores import FAISS |
|
from langchain.embeddings.base import Embeddings |
|
from langchain_community.vectorstores.utils import DistanceStrategy |
|
from transformers import RagRetriever |
|
from langchain.docstore.document import Document as LangchainDocument |
|
|
|
def init_vectorDB_from_doc(documents: List[LangchainDocument], embedding_model: Embeddings) -> FAISS:
    """Embed ``documents`` and index them in a freshly built FAISS vector store.

    Args:
        documents: LangChain documents whose content is embedded and indexed.
        embedding_model: Embedding model used to vectorise every document.

    Returns:
        A new FAISS vector store created with ``DistanceStrategy.COSINE``.
    """
    # NOTE(review): cosine ranking is only equivalent to the underlying index's
    # ordering when the embeddings are L2-normalised -- confirm the embedding
    # model is configured to normalise its outputs.
    return FAISS.from_documents(
        documents,
        embedding_model,
        distance_strategy=DistanceStrategy.COSINE,
    )
|
|
|
|
|
|
|
|
|
|
|
def retriever(
    user_query: str,
    vectorDB: FAISS,
    reranker=None,
    num_doc_before_rerank: int = 5,
    num_final_relevant_docs: int = 5,
    rerank: bool = True,
) -> List[str]:
    """Return the text of the documents most relevant to ``user_query``.

    The vector store is first queried for ``num_doc_before_rerank`` candidates.
    If reranking is enabled and a reranker is supplied, the candidates are
    reordered by the reranker. In every case at most
    ``num_final_relevant_docs`` texts are returned.

    Args:
        user_query: Natural-language query to search for.
        vectorDB: Vector store queried via ``similarity_search``.
        reranker: Optional object exposing ``rerank(query, documents, k=...)``
            that returns a list of dicts with a ``"content"`` key (presumably a
            RAGatouille model -- confirm against callers). ``None`` disables
            reranking.
        num_doc_before_rerank: Number of candidates fetched from the store.
        num_final_relevant_docs: Maximum number of texts returned.
        rerank: Set to ``False`` to skip reranking even if ``reranker`` is given.

    Returns:
        The page contents of the selected documents, most relevant first.
    """
    candidates = vectorDB.similarity_search(query=user_query, k=num_doc_before_rerank)
    relevant_docs = [doc.page_content for doc in candidates]

    print("=> Relevant documents:")
    print(relevant_docs)

    # Skip the reranker when there is nothing to rerank; calling it with an
    # empty list is wasted work and may fail depending on the implementation.
    if rerank and reranker is not None and relevant_docs:
        # Never ask the reranker for more results than there are candidates.
        top_k = min(num_final_relevant_docs, len(relevant_docs))
        reranked = reranker.rerank(user_query, relevant_docs, k=top_k)
        final_relevant_docs = [doc["content"] for doc in reranked]

        print("=> Reranked documents:")
        print(final_relevant_docs)
    else:
        final_relevant_docs = relevant_docs

    # Honour the requested final size on both paths. Previously the
    # non-rerank path returned all ``num_doc_before_rerank`` candidates.
    final_relevant_docs = final_relevant_docs[:num_final_relevant_docs]

    print("=> Final relevant documents:")
    print(final_relevant_docs)

    return final_relevant_docs
|
|