# mistral-rag/retriever.py
import chromadb
from sentence_transformers import SentenceTransformer

# Initialize a persistent ChromaDB client (data is stored on disk)
chroma_client = chromadb.PersistentClient(path="data/chroma_db")

# Load the embedding model
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Create or open the collection (vector storage)
collection = chroma_client.get_or_create_collection(name="documents")
def add_document(text: str):
    """Add a document to the vector database."""
    embedding = embed_model.encode(text).tolist()
    # The first 10 characters of the text serve as the document ID,
    # so documents sharing a prefix will overwrite each other.
    collection.add(ids=[text[:10]], embeddings=[embedding], documents=[text])
def retrieve_documents(query: str, top_k: int = 3):
    """Retrieve the most relevant documents for a given query."""
    query_embedding = embed_model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
    # Chroma returns one list of documents per query embedding, so take the first entry.
    docs = results["documents"][0] if results["documents"] else []
    return " ".join(docs) if docs else "No relevant documents found."
# Add some example data (you can replace this with actual text files)
if collection.count() == 0:
    add_document("Mistral 7B is a powerful AI model designed for language generation.")
    add_document("RAG stands for Retrieval-Augmented Generation, enhancing AI models with knowledge retrieval.")
    add_document("ChromaDB is a vector database optimized for embeddings and search.")