import logging
from typing import List, Dict, Any

from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.documents import Document
from expiringdict import ExpiringDict

from core.data_loader import DataLoader
from core.user_profile import UserProfile
from config.config import settings

logger = logging.getLogger(__name__)


class RAGEngine:
    """
    The core Retrieval-Augmented Generation engine for the TravelMate chatbot.
    This class handles model initialization, vector store creation, and query processing.
    """

    def __init__(self, user_profile: UserProfile):
        """
        Initializes the RAG engine, loading models and setting up the QA chain.
        """
        self.user_profile = user_profile
        self.query_cache = ExpiringDict(max_len=settings.MAX_CACHE_SIZE, max_age_seconds=settings.CACHE_TTL)
        try:
            self.embeddings = self._initialize_embeddings()
            self.vector_store = self._initialize_vector_store()
            self.llm = self._initialize_llm()
            self.qa_chain = self._create_rag_chain()
            logger.info("RAG Engine initialized successfully.")
        except Exception as e:
            logger.critical(f"Failed to initialize RAG Engine: {e}", exc_info=True)
            raise

    def _initialize_embeddings(self) -> HuggingFaceEmbeddings:
        """Initializes the sentence-transformer embeddings model."""
        return HuggingFaceEmbeddings(
            model_name=settings.EMBEDDING_MODEL_NAME,
            model_kwargs={'device': 'cpu'}
        )

    def _initialize_vector_store(self) -> FAISS:
        """
        Initializes the FAISS vector store. Loads from disk if it exists,
        otherwise creates it from the data loader.
        """
        if settings.VECTOR_STORE_DIR.exists() and any(settings.VECTOR_STORE_DIR.iterdir()):
            logger.info(f"Loading existing vector store from {settings.VECTOR_STORE_DIR}...")
            return FAISS.load_local(
                folder_path=str(settings.VECTOR_STORE_DIR),
                embeddings=self.embeddings,
                allow_dangerous_deserialization=True
            )
        else:
            logger.info("Creating new vector store from scratch.")
            data_loader = DataLoader()
            documents = data_loader.load_documents()
            if not documents:
                raise ValueError("No documents were loaded. Cannot create vector store.")
            vector_store = FAISS.from_documents(documents, self.embeddings)
            logger.info(f"Saving new vector store to {settings.VECTOR_STORE_DIR}...")
            vector_store.save_local(str(settings.VECTOR_STORE_DIR))
            return vector_store

    def _initialize_llm(self) -> HuggingFaceEndpoint:
        """Initializes the Hugging Face Inference Endpoint for the LLM."""
        if not settings.HUGGINGFACE_API_TOKEN:
            raise ValueError("HUGGINGFACE_API_TOKEN is not set.")
        return HuggingFaceEndpoint(
            repo_id=settings.MODEL_NAME,
            huggingfacehub_api_token=settings.HUGGINGFACE_API_TOKEN,
            temperature=settings.TEMPERATURE,
            max_new_tokens=settings.MAX_NEW_TOKENS,
            repetition_penalty=settings.REPETITION_PENALTY,
        )

    def _create_rag_chain(self):
        """Creates a modern, streamlined RAG chain for question answering."""
        qa_prompt = PromptTemplate.from_template(settings.QA_PROMPT_TEMPLATE)
        question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
        retriever = self.vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={'k': settings.TOP_K_RESULTS, 'score_threshold': settings.SIMILARITY_THRESHOLD}
        )
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)
        return rag_chain

    def _format_sources(self, sources: List[Document]) -> List[Dict[str, Any]]:
        """Formats source documents into a serializable list of dictionaries."""
        if not sources:
            return []
        formatted_list = []
        for source in sources:
            metadata = source.metadata
            source_name = metadata.get('source', 'Unknown Source')
            if source_name == 'huggingface':
                title = f"Dataset: {metadata.get('intent', 'N/A')}"
                category = metadata.get('category', 'N/A')
            elif source_name == 'local_guides':
                title = f"Guide: {metadata.get('title', 'N/A')}"
                category = metadata.get('category', 'N/A')
            else:
                title = "Unknown Source"
                category = "N/A"
            formatted_list.append({"title": title, "category": category})
        return formatted_list

    async def process_query(self, query: str, user_id: str) -> Dict[str, Any]:
        """Processes a user query asynchronously using the streamlined RAG chain."""
        cache_key = f"{user_id}:{query}"
        if cache_key in self.query_cache:
            logger.info(f"Returning cached response for query: {query}")
            return self.query_cache[cache_key]

        logger.info(f"Processing query for user {user_id}: {query}")
        # The new chain expects 'input' instead of 'question'
        chain_input = {"input": query}
        try:
            result = await self.qa_chain.ainvoke(chain_input)
            answer = result.get("answer", "Sorry, I couldn't find an answer.")
            # The new chain returns retrieved documents in the 'context' key
            sources = self._format_sources(result.get("context", []))
            response = {"answer": answer, "sources": sources}
            self.query_cache[cache_key] = response
            logger.info(f"Successfully processed query for user {user_id}")
            return response
        except Exception as e:
            logger.error(f"Error processing query for user {user_id}: {e}", exc_info=True)
            return {"answer": "I'm sorry, but I encountered an error while processing your request.", "sources": []}
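

# --- Usage sketch (illustrative only) ---
# A minimal sketch of driving the engine end to end, mainly to show the async
# process_query API. The UserProfile constructor arguments and the sample query
# below are assumptions made for this sketch; the real signature lives in
# core.user_profile and may differ.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Assumed constructor; adjust to the real UserProfile API.
        profile = UserProfile(user_id="demo-user")
        engine = RAGEngine(user_profile=profile)
        result = await engine.process_query(
            "When is the best time to visit Lisbon?", user_id="demo-user"
        )
        print(result["answer"])
        for source in result["sources"]:
            print(f"- {source['title']} ({source['category']})")

    asyncio.run(_demo())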