# mistral-rag / retriever.py
import hashlib
import os

# Point all Hugging Face caches at a writable location. These variables must be
# set before importing the libraries below, which resolve cache paths at import time.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/tmp/huggingface"

import chromadb
from sentence_transformers import SentenceTransformer
# Ensure the cache directory exists
os.makedirs("/tmp/huggingface", exist_ok=True)
# Initialize ChromaDB in in-memory mode (the index resets on every restart).
# For an on-disk index, swap in: chromadb.PersistentClient(path="data/chroma_db")
chroma_client = chromadb.EphemeralClient()
# Load the embedding model (all-MiniLM-L6-v2 produces 384-dimensional vectors)
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# Create collection (vector storage)
collection = chroma_client.get_or_create_collection(name="documents")
def add_document(text: str) -> None:
    """Add a document to the vector database."""
    embedding = embed_model.encode(text).tolist()
    # Use a content hash as the ID; a raw text prefix would collide for
    # documents that share their first ten characters.
    doc_id = hashlib.sha1(text.encode("utf-8")).hexdigest()
    collection.add(ids=[doc_id], embeddings=[embedding], documents=[text])
def retrieve_documents(query: str, top_k: int = 3) -> str:
    """Retrieve the most relevant documents for a given query."""
    query_embedding = embed_model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
    # Chroma returns one list of documents per query embedding; we sent a
    # single query, so take the first (and only) inner list.
    retrieved_texts = results["documents"][0] if results["documents"] else []
    return " ".join(retrieved_texts) if retrieved_texts else "No relevant documents found."
# Add some example data (you can replace this with actual text files)
if collection.count() == 0:
    add_document("Mistral 7B is a powerful AI model designed for language generation.")
    add_document("RAG stands for Retrieval-Augmented Generation, enhancing AI models with knowledge retrieval.")
    add_document("ChromaDB is a vector database optimized for embeddings and search.")