import os
import json
import logging
import stat
import time
from typing import Any, List

from config.config import settings
from datasets import load_dataset
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

logger = logging.getLogger(__name__)


class DataLoader:
    """Handles loading and processing of data for the RAG engine."""

    def __init__(self):
        """Initialize the data loader."""
        self.data_dir = os.path.abspath("data")
        self.travel_guides_path = os.path.join(self.data_dir, "travel_guides.json")
        self.vector_store_path = os.path.join(self.data_dir, "vector_store", "faiss_index")
        self._ensure_data_directories()
        self._set_directory_permissions()

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=settings.CHUNK_SIZE,
            chunk_overlap=settings.CHUNK_OVERLAP,
            length_function=len,
            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
        )
        self.max_file_size = 10 * 1024 * 1024  # 10 MB

    def _ensure_data_directories(self):
        """Ensure necessary data directories exist."""
        os.makedirs(self.data_dir, exist_ok=True)
        os.makedirs(os.path.dirname(self.vector_store_path), exist_ok=True)
        os.makedirs(os.path.join(self.data_dir, "cache"), exist_ok=True)

    def _set_directory_permissions(self):
        """Set secure permissions for data directories (755)."""
        try:
            for dir_path in [
                self.data_dir,
                os.path.dirname(self.vector_store_path),
                os.path.join(self.data_dir, "cache"),
            ]:
                os.chmod(
                    dir_path,
                    stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH,
                )
        except Exception as e:
            logger.error(f"Error setting directory permissions: {e}", exc_info=True)

    def _validate_file_permissions(self, file_path: str) -> bool:
        """Validate file permissions to ensure security."""
        try:
            if not os.path.exists(file_path):
                return False
            file_stat = os.stat(file_path)
            if file_stat.st_mode & stat.S_IWOTH:  # Disallow world-writable files
                logger.warning(f"File {file_path} is world-writable. Skipping.")
                return False
            if file_stat.st_size > self.max_file_size:
                logger.warning(f"File {file_path} exceeds size limit. Skipping.")
                return False
            return True
        except Exception as e:
            logger.error(f"Error validating file permissions for {file_path}: {e}", exc_info=True)
            return False

    def _load_dataset_with_retry(self, max_retries: int = 3) -> Any:
        """Load the dataset from Hugging Face with exponential backoff retries."""
        for attempt in range(max_retries):
            try:
                return load_dataset(
                    settings.DATASET_ID,
                    split="train",
                    cache_dir=os.path.join(self.data_dir, "cache"),
                )
            except Exception as e:
                logger.warning(f"Dataset loading attempt {attempt + 1} failed: {e}")
                if attempt == max_retries - 1:
                    logger.error("All attempts to load dataset failed.")
                    return None
                time.sleep(2 ** attempt)
        return None

    def load_documents(self) -> List[Document]:
        """Load and process all documents for the knowledge base."""
        documents = []
        try:
            # 1. Load Bitext Travel Dataset
            logger.info(f"Loading dataset: {settings.DATASET_ID}")
            dataset = self._load_dataset_with_retry()
            if dataset:
                max_docs = settings.MAX_DOCUMENTS_TO_LOAD
                logger.info(f"Loading up to {max_docs} documents from the dataset.")
                for i, item in enumerate(dataset):
                    if i >= max_docs:
                        logger.info(f"Reached document limit ({max_docs}).")
                        break
                    instruction = item.get("instruction")
                    response = item.get("response")
                    if not instruction or not response:
                        logger.warning(f"Skipping item with missing instruction or response: {item}")
                        continue
                    page_content = f"User query: {instruction}\n\nChatbot response: {response}"
                    metadata = {
                        "source": "huggingface",
                        "intent": item.get("intent"),
                        "category": item.get("category"),
                        "tags": item.get("tags"),
                    }
                    documents.append(Document(page_content=page_content, metadata=metadata))

            # 2. Load Local Travel Guides
            logger.info("Loading local travel guides...")
            if os.path.exists(self.travel_guides_path) and self._validate_file_permissions(self.travel_guides_path):
                with open(self.travel_guides_path, "r", encoding="utf-8") as f:
                    guides = json.load(f)
                for guide in guides:
                    if not all(k in guide for k in ["title", "content", "category"]):
                        logger.warning(f"Skipping malformed guide: {guide}")
                        continue
                    doc = Document(
                        page_content=guide["content"],
                        metadata={
                            "title": guide["title"],
                            "category": guide["category"],
                            "source": "travel_guide",
                        },
                    )
                    documents.append(doc)
            else:
                logger.info("Travel guides file not found or invalid. Skipping.")

            logger.info(f"Loaded {len(documents)} documents in total.")
            return documents
        except Exception as e:
            logger.error(f"A critical error occurred while loading documents: {e}", exc_info=True)
            return []

    def create_vector_store(self, documents: List[Document]):
        """Create a FAISS vector store from documents."""
        try:
            logger.info("Creating vector store...")
            embeddings = HuggingFaceEmbeddings(
                model_name=settings.EMBEDDING_MODEL_NAME,
                model_kwargs={"device": "cpu"},
                encode_kwargs={"normalize_embeddings": True},
            )
            split_docs = self.text_splitter.split_documents(documents)
            vector_store = FAISS.from_documents(
                documents=split_docs,
                embedding=embeddings,
            )
            vector_store.save_local(self.vector_store_path)
            logger.info(f"Vector store created and saved to {self.vector_store_path} with {len(split_docs)} chunks.")
        except Exception as e:
            logger.error(f"Error creating vector store: {e}", exc_info=True)
            raise
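
    # Illustrative note (an assumption, not part of the original class): the index
    # saved above can later be restored with the same embedding model, e.g.
    #   embeddings = HuggingFaceEmbeddings(model_name=settings.EMBEDDING_MODEL_NAME)
    #   store = FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True)
    # Newer langchain_community releases require allow_dangerous_deserialization=True
    # because the docstore metadata is pickled on disk.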

    def initialize_knowledge_base(self):
        """Initialize the complete knowledge base."""
        try:
            logger.info("Initializing knowledge base...")
            documents = self.load_documents()
            if not documents:
                logger.error("No documents were loaded. Aborting knowledge base initialization.")
                return
            self.create_vector_store(documents)
            logger.info("Knowledge base initialized successfully.")
        except Exception as e:
            logger.critical(f"Failed to initialize knowledge base: {e}", exc_info=True)
            raise
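

# Minimal usage sketch (an added example, not from the original file). It assumes
# config.config.settings defines DATASET_ID, CHUNK_SIZE, CHUNK_OVERLAP,
# MAX_DOCUMENTS_TO_LOAD and EMBEDDING_MODEL_NAME, and that the Hugging Face
# dataset is reachable from this environment.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    loader = DataLoader()
    loader.initialize_knowledge_base()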