# app/rag_pipeline/retriever_chain.py
import logging
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from app.rag_pipeline.prompt_utils import qa_prompt
from app.rag_pipeline.chroma_client import get_chroma_client
from app.settings import Config
# The commented imports below mirror the ones above for running this module
# as a standalone script from inside app/rag_pipeline (see the example at
# the bottom of the file):
# import os
# import sys
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
# from prompt_utils import qa_prompt
# from chroma_client import get_chroma_client
# from settings import Config
conf = Config()

# Model-loading settings. MODELS_PATH, MODEL_ID and MODEL_BASENAME come from
# the config; the rest are local defaults for a llama.cpp-style model. These
# are only used by the standalone example at the bottom of the file.
MODELS_PATH = conf.MODELS_PATH          # e.g. '/models'
CONTEXT_WINDOW_SIZE = 2048
MAX_NEW_TOKENS = 2048
N_BATCH = 512
N_GPU_LAYERS = 1
MODEL_ID = conf.MODEL_ID                # e.g. "TheBloke/Mistral-7B-v0.1-GGUF"
MODEL_BASENAME = conf.MODEL_BASENAME    # e.g. "mistral-7b-v0.1.Q4_0.gguf"
device_type = 'cpu'
logger = logging.getLogger(__name__)
class RetrieverChain:
    """Builds retrieval and retrieval-augmented QA chains over a Chroma vector store."""

    def __init__(self, collection_name, embedding_function, persist_directory):
try:
self.vector_db = get_chroma_client(collection_name, embedding_function, persist_directory)
except Exception as e:
logger.error(f"Error creating RetrieverChain: {e}")
raise
    def get_retriever(self):
        try:
            # MMR retrieval: fetch a larger candidate pool (fetch_k), then
            # select k diverse documents from it, so fetch_k must be >= k.
            # (The original fetch_k of 2 was smaller than k and would starve
            # the MMR re-ranking step.)
            retriever = self.vector_db.as_retriever(
                search_type="mmr", search_kwargs={"k": 5, "fetch_k": 20}
            )
            return retriever
        except Exception as e:
            logger.error(f"Failed to get retriever: {e}")
            raise
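    # Usage sketch (illustrative only, not one of this app's call sites): the
    # returned object is a standard LangChain retriever, so it can be queried
    # directly:
    #   chain = RetrieverChain("my_collection", embedding_fn, "./vector_store")
    #   docs = chain.get_retriever().get_relevant_documents("What is RAG?")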
    def get_conversational_rag_chain(self, llm):
        try:
            # Call get_retriever() rather than testing the bound method itself,
            # which would always be truthy.
            retriever = self.get_retriever()
            if retriever is None:
                logger.error("Retriever must not be None")
                raise ValueError("Retriever must not be None")
            if llm is None:
                logger.error("Model must not be None")
                raise ValueError("Model must not be None")
            # "Stuff" all retrieved documents into qa_prompt, then combine the
            # retriever and the QA chain into a single retrieval chain.
            question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
            return create_retrieval_chain(retriever, question_answer_chain)
        except Exception as e:
            logger.error(f"Error creating RAG chain: {e}")
            raise
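    # For reference, chains built by create_retrieval_chain take a dict of the
    # form {"input": <question>} and return a dict containing the original
    # "input", "context" (the retrieved Documents) and "answer":
    #   result = chain.get_conversational_rag_chain(llm).invoke({"input": "..."})
    #   result["answer"]; result["context"]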
    def get_relevent_docs(self, user_input):
        try:
            # Same MMR setup as get_retriever, but with a slightly larger k.
            # As above, fetch_k must be >= k for MMR to have candidates to
            # re-rank.
            retriever = self.vector_db.as_retriever(
                search_type="mmr", search_kwargs={"k": 6, "fetch_k": 20}
            )
            docs = retriever.get_relevant_documents(user_input)
            logger.info(f"Relevant documents for {user_input}: {docs}")
            # Each returned Document carries .page_content (the original text)
            # and .metadata (any metadata stored alongside it).
            return docs
        except Exception as e:
            logger.error(f"Error getting relevant documents: {e}")
            raise
    def get_response(self, user_input, llm):
        try:
            qa_rag_chain = self.get_conversational_rag_chain(llm)
            response = qa_rag_chain.invoke({"input": user_input})
            # The retrieval chain returns a dict; the generated text is under
            # the "answer" key.
            return response['answer']
        except Exception as e:
            logger.error(f"Error getting response: {e}")
            raise
# if __name__ == "__main__":
#     from model_initializer import initialize_models
#
#     openai_api_key = conf.API_KEY
#     embedding_model, llm_model = initialize_models(
#         openai_api_key, model_id=MODEL_ID, model_basename=MODEL_BASENAME)
#     print(f"embedding_model: {embedding_model}")
#     print(f"llm_model: {llm_model}")
#
#     collection_name = 'AI_assignment'
#     persist_directory = 'D:/AI Assignment/vector_store'
#     print(f"persist_directory: {persist_directory}")
#
#     # Build the chain once, outside the query loop.
#     retriever_qa = RetrieverChain(
#         collection_name=collection_name,
#         embedding_function=embedding_model,
#         persist_directory=persist_directory)
#
#     while True:
#         user_query = input("Enter query: ")
#         if user_query.lower() == 'exit':
#             break
#         response = retriever_qa.get_response(user_input=user_query, llm=llm_model)
#         print(f"Response: {response}")