# TravelMate-AI: core/data_loader.py
import os
import json
import logging
import stat
import time
from typing import Any, List
from config.config import settings
from datasets import load_dataset
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
logger = logging.getLogger(__name__)
class DataLoader:
"""Handles loading and processing of data for the RAG engine."""
def __init__(self):
"""Initialize the data loader."""
self.data_dir = os.path.abspath("data")
self.travel_guides_path = os.path.join(self.data_dir, "travel_guides.json")
self.vector_store_path = os.path.join(self.data_dir, "vector_store", "faiss_index")
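        # On-disk layout: data/travel_guides.json, data/vector_store/faiss_index,
        # and data/cache (Hugging Face dataset downloads).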
self._ensure_data_directories()
self._set_directory_permissions()
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=settings.CHUNK_SIZE,
chunk_overlap=settings.CHUNK_OVERLAP,
length_function=len,
separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
)
self.max_file_size = 10 * 1024 * 1024 # 10MB
def _ensure_data_directories(self):
"""Ensure necessary data directories exist."""
os.makedirs(self.data_dir, exist_ok=True)
os.makedirs(os.path.dirname(self.vector_store_path), exist_ok=True)
os.makedirs(os.path.join(self.data_dir, "cache"), exist_ok=True)
def _set_directory_permissions(self):
"""Set secure permissions for data directories (755)."""
try:
for dir_path in [
self.data_dir,
os.path.dirname(self.vector_store_path),
os.path.join(self.data_dir, "cache"),
]:
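                # S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH == 0o755 (rwxr-xr-x).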
os.chmod(
dir_path,
stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH,
)
except Exception as e:
logger.error(f"Error setting directory permissions: {e}", exc_info=True)
def _validate_file_permissions(self, file_path: str) -> bool:
"""Validate file permissions to ensure security."""
try:
if not os.path.exists(file_path):
return False
file_stat = os.stat(file_path)
if file_stat.st_mode & stat.S_IWOTH: # Disallow world-writable
logger.warning(f"File {file_path} is world-writable. Skipping.")
return False
if file_stat.st_size > self.max_file_size:
logger.warning(f"File {file_path} exceeds size limit. Skipping.")
return False
return True
except Exception as e:
logger.error(f"Error validating file permissions for {file_path}: {e}", exc_info=True)
return False
def _load_dataset_with_retry(self, max_retries: int = 3) -> Any:
"""Load dataset from Hugging Face with an exponential backoff retry mechanism."""
for attempt in range(max_retries):
try:
return load_dataset(
settings.DATASET_ID,
split="train",
cache_dir=os.path.join(self.data_dir, "cache"),
)
except Exception as e:
logger.warning(f"Dataset loading attempt {attempt + 1} failed: {e}")
if attempt == max_retries - 1:
logger.error("All attempts to load dataset failed.")
return None
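                # Exponential backoff: wait 1s, 2s, 4s, ... before the next attempt.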
time.sleep(2 ** attempt)
return None
def load_documents(self) -> List[Document]:
"""Load and process all documents for the knowledge base."""
documents = []
try:
# 1. Load Bitext Travel Dataset
logger.info(f"Loading dataset: {settings.DATASET_ID}")
dataset = self._load_dataset_with_retry()
if dataset:
max_docs = settings.MAX_DOCUMENTS_TO_LOAD
logger.info(f"Loading up to {max_docs} documents from the dataset.")
for i, item in enumerate(dataset):
if i >= max_docs:
logger.info(f"Reached document limit ({max_docs}).")
break
instruction = item.get("instruction")
response = item.get("response")
if not instruction or not response:
logger.warning(f"Skipping item with missing instruction or response: {item}")
continue
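                    # Merge the query/response pair into one retrievable passage.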
page_content = f"User query: {instruction}\n\nChatbot response: {response}"
metadata = {
"source": "huggingface",
"intent": item.get("intent"),
"category": item.get("category"),
"tags": item.get("tags"),
}
documents.append(Document(page_content=page_content, metadata=metadata))
# 2. Load Local Travel Guides
logger.info("Loading local travel guides...")
if os.path.exists(self.travel_guides_path) and self._validate_file_permissions(self.travel_guides_path):
with open(self.travel_guides_path, "r", encoding="utf-8") as f:
guides = json.load(f)
for guide in guides:
if not all(k in guide for k in ["title", "content", "category"]):
logger.warning(f"Skipping malformed guide: {guide}")
continue
doc = Document(
page_content=guide["content"],
metadata={
"title": guide["title"],
"category": guide["category"],
"source": "travel_guide",
},
)
documents.append(doc)
else:
logger.info("Travel guides file not found or invalid. Skipping.")
logger.info(f"Loaded {len(documents)} documents in total.")
return documents
except Exception as e:
logger.error(f"A critical error occurred while loading documents: {e}", exc_info=True)
return []
def create_vector_store(self, documents: List[Document]):
"""Create a FAISS vector store from documents."""
try:
logger.info("Creating vector store...")
embeddings = HuggingFaceEmbeddings(
model_name=settings.EMBEDDING_MODEL_NAME,
model_kwargs={"device": "cpu"},
encode_kwargs={"normalize_embeddings": True},
)
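            # Chunk documents before embedding so each vector covers a bounded span of text.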
split_docs = self.text_splitter.split_documents(documents)
vector_store = FAISS.from_documents(
documents=split_docs,
embedding=embeddings,
)
vector_store.save_local(self.vector_store_path)
logger.info(f"Vector store created and saved to {self.vector_store_path} with {len(split_docs)} chunks.")
except Exception as e:
logger.error(f"Error creating vector store: {e}", exc_info=True)
raise
def initialize_knowledge_base(self):
"""Initialize the complete knowledge base."""
try:
logger.info("Initializing knowledge base...")
documents = self.load_documents()
if not documents:
logger.error("No documents were loaded. Aborting knowledge base initialization.")
return
self.create_vector_store(documents)
logger.info("Knowledge base initialized successfully.")
except Exception as e:
logger.critical(f"Failed to initialize knowledge base: {e}", exc_info=True)
raise
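

# Example entry point for building the knowledge base manually. Illustrative sketch only:
# it assumes config.config.settings already provides DATASET_ID, EMBEDDING_MODEL_NAME,
# CHUNK_SIZE, CHUNK_OVERLAP and MAX_DOCUMENTS_TO_LOAD, as used above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    loader = DataLoader()
    loader.initialize_knowledge_base()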