SamerPF's picture
Update rag_agent.py
d219d08 verified
import os
import requests
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from transformers import AutoTokenizer, AutoModelForCausalLM
from llama_index.llms.llama_cpp import LlamaCPP
import os
from huggingface_hub import login
# Read the Hugging Face API token from the environment.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
# Authenticate only when a token is actually configured: calling login() with
# token=None falls back to an interactive prompt, which stalls or fails when
# the script runs without a terminal.
if hf_token:
    login(token=hf_token)
else:
    print("⚠️ HUGGINGFACEHUB_API_TOKEN is not set; skipping Hugging Face login.")
# ==== 1. Set up Hugging Face Embedding Model ====
# Build the llama_index HuggingFaceEmbedding wrapper first, then register it
# as the embedding model on the global Settings object.
_embedding_model = HuggingFaceEmbedding(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)
Settings.embed_model = _embedding_model
# ==== 2. Load Hugging Face LLM (Locally Installed or Remote Hosted) ====
# Alternative kept for reference: load the model through HuggingFaceLLM
# instead of LlamaCPP.
# llm = HuggingFaceLLM(
#     model_name="unsloth/Mistral-Small-3.1-24B-Instruct-2503-GGUF",  # You must have access!
#     tokenizer_name="unsloth/Mistral-Small-3.1-24B-Instruct-2503-GGUF",
#     context_window=2048,
#     max_new_tokens=512,
#     generate_kwargs={"temperature": 0.1},
#     tokenizer_kwargs={"padding_side": "left"},
#     device_map="auto",  # Automatically assign model layers to available devices
# )

# Location of the GGUF model file. It can be overridden through the
# LLAMA_MODEL_PATH environment variable so the placeholder below does not have
# to be edited in the source; the default value is unchanged.
_llm_model_path = os.getenv("LLAMA_MODEL_PATH", "/path/to/your/model.gguf")

llm = LlamaCPP(
    model_path=_llm_model_path,
    temperature=0.1,
    max_new_tokens=512,
    context_window=4096,
    generate_kwargs={"stop": ["</s>"]},
    model_kwargs={"n_threads": 4},  # adjust for your CPU
)
Settings.llm = llm  # Apply to global settings
# ==== 3. Validate & Download ArXiv PDFs (if needed) ====
def download_pdf(arxiv_id, save_dir="kb", timeout=30):
    """Download the PDF of an arXiv paper into ``save_dir``.

    Requests ``https://arxiv.org/pdf/<arxiv_id>.pdf`` and, when the server
    answers successfully with a PDF, writes the body to a ``.pdf`` file inside
    ``save_dir``. The outcome is reported with ``print``; the function returns
    ``None`` in every case, and request failures are reported rather than
    raised.

    Args:
        arxiv_id: arXiv identifier, e.g. ``"2312.03840"``. Any ``/`` in the
            identifier is replaced by ``_`` in the file name so the file is
            written directly inside ``save_dir`` instead of a missing
            subdirectory.
        save_dir: Directory that receives the file; created if it is missing.
        timeout: Seconds to wait for the server before giving up.
    """
    url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
    try:
        # Without a timeout a stalled connection would block indefinitely.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        print(f"❌ Failed to download PDF for {arxiv_id}: {exc}")
        return

    # Require both a successful status and a PDF content type, so an error
    # page is never saved under a .pdf name.
    content_type = response.headers.get("Content-Type", "")
    if not response.ok or "application/pdf" not in content_type:
        print(f"❌ Failed to download PDF for {arxiv_id}: Not a valid PDF")
        return

    os.makedirs(save_dir, exist_ok=True)
    file_name = f"{arxiv_id}.pdf".replace("/", "_")
    file_path = os.path.join(save_dir, file_name)
    with open(file_path, "wb") as f:
        f.write(response.content)
    print(f"✅ Downloaded {file_path}")
# Example: download_pdf("2312.03840")
# ==== 4. Load Knowledge Base ====
# Load documents from the "kb" directory, restricted to .pdf files.
_kb_reader = SimpleDirectoryReader("kb", required_exts=[".pdf"])
documents = _kb_reader.load_data()
# Build a vector index over the loaded documents.
index = VectorStoreIndex.from_documents(documents)
# ==== 5. Create Query Engine ====
query_engine = index.as_query_engine()
# ==== 6. Wrap as a Tool ====
# Describe the tool first, then bind the description to the query engine.
_rag_metadata = ToolMetadata(
    name="RAGSearch",
    description="Answers from a local HF-based RAG system.",
)
rag_tool = QueryEngineTool(query_engine=query_engine, metadata=_rag_metadata)
# ==== 7. Basic Agent ====
class BasicAgent:
    """Callable agent that answers questions through a query-engine tool.

    Args:
        tool: Object exposing ``query_engine.query(question)``. When omitted,
            the module-level ``rag_tool`` is used, which matches the previous
            behavior.
    """

    def __init__(self, tool=None):
        # The global is only looked up when no tool is supplied, so a custom
        # tool can be injected without touching module state.
        self.tool = rag_tool if tool is None else tool

    def __call__(self, question: str) -> str:
        """Send ``question`` to the tool's query engine and return the answer as text."""
        print(f"🧠 RAG Agent received: {question}")
        response = self.tool.query_engine.query(question)
        return str(response)