import os
import requests
from huggingface_hub import login
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.llms.llama_cpp import LlamaCPP

# Get the Hugging Face API token from the environment and log in.
# Skip the login (with a warning) if the token is not set, so the script
# does not fall back to an interactive prompt.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("⚠️ HUGGINGFACEHUB_API_TOKEN is not set; gated models may be inaccessible.")


# ==== 1. Set up Hugging Face Embedding Model ====
# Use HuggingFaceEmbedding from llama_index directly
Settings.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
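
# Optional check (illustrative sketch, not part of the original flow):
# all-MiniLM-L6-v2 produces 384-dimensional embeddings.
# print(len(Settings.embed_model.get_text_embedding("hello world")))  # -> 384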

# ==== 2. Load Hugging Face LLM (Locally Installed or Remote Hosted) ====
#llm = HuggingFaceLLM(
#    model_name="unsloth/Mistral-Small-3.1-24B-Instruct-2503-GGUF",  # You must have access!
#    tokenizer_name="unsloth/Mistral-Small-3.1-24B-Instruct-2503-GGUF",
#    context_window=2048,
#    max_new_tokens=512,
#    generate_kwargs={"temperature": 0.1},
#    tokenizer_kwargs={"padding_side": "left"},
#    device_map="auto"  # Automatically assign model layers to available devices
#
#)
llm = LlamaCPP(
    model_path="/path/to/your/model.gguf",
    temperature=0.1,
    max_new_tokens=512,
    context_window=4096,
    generate_kwargs={"stop": ["</s>"]},
    model_kwargs={"n_threads": 4},  # adjust for your CPU
)

Settings.llm = llm  # Apply to global settings
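
# Optional sanity check (an illustrative sketch, not part of the original pipeline):
# confirms the GGUF model loads and can generate text. Assumes model_path above
# points at a real .gguf file on disk.
# print(llm.complete("Say hello in one short sentence."))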

# ==== 3. Validate & Download ArXiv PDFs (if needed) ====
def download_pdf(arxiv_id, save_dir="kb"):
    url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
    response = requests.get(url, timeout=30)  # timeout so a bad request does not hang

    if "application/pdf" in response.headers.get("Content-Type", ""):
        os.makedirs(save_dir, exist_ok=True)
        file_path = os.path.join(save_dir, f"{arxiv_id}.pdf")
        with open(file_path, "wb") as f:
            f.write(response.content)
        print(f"✅ Downloaded {file_path}")
    else:
        print(f"❌ Failed to download PDF for {arxiv_id}: Not a valid PDF")

# Example: download_pdf("2312.03840")

# ==== 4. Load Knowledge Base ====
documents = SimpleDirectoryReader("kb", required_exts=[".pdf"]).load_data()
index = VectorStoreIndex.from_documents(documents)
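
# Optional: persist the index so it does not need to be rebuilt on every run
# (a sketch; "storage" is an assumed directory name, not part of the original setup).
# index.storage_context.persist(persist_dir="storage")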

# ==== 5. Create Query Engine ====
query_engine = index.as_query_engine()

# ==== 6. Wrap as a Tool ====
rag_tool = QueryEngineTool(
    query_engine=query_engine,
    metadata=ToolMetadata(name="RAGSearch", description="Answers questions using the local arXiv PDF knowledge base (HF embeddings + local GGUF LLM).")
)

# ==== 7. Basic Agent ====
class BasicAgent:
    def __init__(self):
        self.tool = rag_tool

    def __call__(self, question: str) -> str:
        print(f"🧠 RAG Agent received: {question}")
        response = self.tool.query_engine.query(question)
        return str(response)
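
# ==== 8. Example Usage ====
# A minimal, illustrative sketch of running the agent end to end. The question
# below is hypothetical and assumes the "kb" directory already contains PDFs
# (e.g. via download_pdf("2312.03840")).
if __name__ == "__main__":
    agent = BasicAgent()
    answer = agent("What problem does the paper 2312.03840 address?")
    print(answer)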