SamerPF committed (verified)
Commit d6def4e · Parent(s): ecf22a7

Update rag_agent.py

Files changed (1)
  1. rag_agent.py +21 -7
rag_agent.py CHANGED
@@ -1,24 +1,38 @@
 import os
 from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
-from llama_index.llms.openai import OpenAI
+from llama_index.llms.huggingface import HuggingFaceLLM
 from llama_index.core.tools import QueryEngineTool, ToolMetadata
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# Build Service Context with preferred model
-llm = OpenAI(model="gpt-3.5-turbo", temperature=0)
+# Load local or remote HF model (example: mistral hosted model)
+model_name = "mistralai/Mistral-7B-Instruct-v0.2"
+
+# Create HuggingFaceLLM
+llm = HuggingFaceLLM(
+    model_name=model_name,
+    tokenizer_name=model_name,
+    context_window=2048,
+    max_new_tokens=512,
+    generate_kwargs={"temperature": 0.1},
+    tokenizer_kwargs={"padding_side": "left"},
+    device_map="auto",
+)
+
+# Build service context
 service_context = ServiceContext.from_defaults(llm=llm)
 
-# Load and index documents
+# Load knowledge base
 documents = SimpleDirectoryReader("kb").load_data()
 index = VectorStoreIndex.from_documents(documents, service_context=service_context)
 query_engine = index.as_query_engine()
 
-# Tool wrapper to integrate with a multi-tool agent if needed
+# Wrap in a tool
 rag_tool = QueryEngineTool(
     query_engine=query_engine,
-    metadata=ToolMetadata(name="RAGSearch", description="Answers questions using a local knowledge base")
+    metadata=ToolMetadata(name="RAGSearch", description="Answers from a local HF-based RAG system.")
 )
 
-# Agent class (Hugging Face-compatible)
+# Agent definition
 class BasicAgent:
     def __init__(self):
         self.tool = rag_tool
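
The diff leaves BasicAgent as a bare wrapper around the tool, so the quickest smoke test of the new HF-backed pipeline is to query the engine (or the tool) directly. Below is a minimal sketch, assuming rag_agent.py is importable, a populated kb/ directory exists, and an installed llama-index version that still ships ServiceContext; the question strings are illustrative.

# Hypothetical usage sketch (not part of the commit): smoke-testing the
# HF-backed pipeline defined in rag_agent.py.
from rag_agent import query_engine, rag_tool

# Ask the knowledge base directly through the query engine.
response = query_engine.query("What topics does the knowledge base cover?")
print(response)

# Or go through the tool wrapper, the way a multi-tool agent would;
# QueryEngineTool.call returns a ToolOutput whose .content holds the answer.
tool_output = rag_tool.call("What topics does the knowledge base cover?")
print(tool_output.content)

Note that rag_agent.py builds the LLM and index at module level, so importing it triggers model loading and index construction up front; with a 7B model that first import can take a while and requires adequate GPU or CPU memory.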