# Provenance (copied from the Hugging Face file viewer): author jarguello76,
# commit 926c41e "Update tools.py" (verified), file size 7.59 kB.
"""LangGraph Agent with CSV-based Vector Store"""
import ast
import os
import warnings
from typing import Callable, Optional, Sequence

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from sklearn.metrics.pairwise import cosine_similarity
load_dotenv()
# Math tools
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    # The @tool decorator publishes the docstring above to the model as this
    # tool's description, so it is intentionally terse and kept verbatim.
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    # Docstring doubles as the tool description exposed via @tool.
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers."""
    # Operand order matters: the result is a minus b.
    difference = a - b
    return difference
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers."""
    # True division (always a float); a zero divisor is rejected explicitly
    # with a readable message instead of Python's ZeroDivisionError.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers."""
    # Python's % takes the sign of the divisor (e.g. -7 % 3 == 1); a zero
    # divisor raises ZeroDivisionError, which propagates to the caller.
    remainder = a % b
    return remainder
# Search tools
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results."""
    # Each article becomes a pseudo-XML <Document> block; blocks are separated
    # by a "---" rule so the model can tell where one article ends.
    blocks = []
    for page in WikipediaLoader(query=query, load_max_docs=2).load():
        meta = page.metadata
        blocks.append(
            f'<Document source="{meta["source"]}" page="{meta.get("page", "")}"/>\n'
            f"{page.page_content}\n"
            "</Document>"
        )
    return "\n\n---\n\n".join(blocks)
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results."""
    # BaseTool.invoke takes the tool input as its first positional argument;
    # a bare keyword such as invoke(query=...) raises TypeError.
    search_docs = TavilySearchResults(max_results=3).invoke({"query": query})
    # TavilySearchResults catches API errors and returns their text as a str
    # instead of a list of result dicts; hand that text back to the model
    # rather than crashing while iterating over its characters.
    if isinstance(search_docs, str):
        return search_docs
    return "\n\n---\n\n".join(
        f'<Document source="{doc.get("url", "")}" title="{doc.get("title", "")}"/>\n'
        f'{doc.get("content", "")}\n'
        "</Document>"
        for doc in search_docs
        if isinstance(doc, dict)
    )
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 results."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    blocks = []
    for doc in search_docs:
        meta = doc.metadata
        # ArxivLoader metadata uses keys like "Title"/"Published" (plus
        # "entry_id" when all metadata is requested) and has no "source" key,
        # so indexing meta["source"] raised KeyError; fall back gracefully.
        source = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
        blocks.append(
            f'<Document source="{source}" page="{meta.get("page", "")}"/>\n'
            # Abstract/page text is truncated to keep the tool output short.
            f"{doc.page_content[:1000]}\n"
            "</Document>"
        )
    return "\n\n---\n\n".join(blocks)
# CSV-based Vector Store Class
class CSVVectorStore:
    """In-memory vector store loaded from a CSV file.

    The CSV must provide ``content``, ``metadata`` and ``embedding`` columns.
    ``embedding`` cells (and ``metadata`` cells stored as text) hold Python
    literals such as ``"[0.1, 0.2]"`` / ``"{'source': 'x'}"``, which are parsed
    with ``ast.literal_eval`` (safe: no arbitrary code execution).
    """

    class Document:
        """Minimal stand-in for LangChain's ``Document`` (text plus metadata)."""

        __slots__ = ("page_content", "metadata")

        def __init__(self, page_content, metadata):
            self.page_content = page_content
            self.metadata = metadata

        def __repr__(self):
            return (
                f"{type(self).__name__}(page_content={self.page_content!r}, "
                f"metadata={self.metadata!r})"
            )

    def __init__(self, csv_file_path: str):
        """Load the CSV and build the row-wise embedding matrix.

        Args:
            csv_file_path: Anything ``pandas.read_csv`` accepts (path or buffer).

        Malformed embedding literals or rows of differing embedding length
        raise whatever ``ast.literal_eval`` / ``numpy.array`` raise.
        """
        self.df = pd.read_csv(csv_file_path)
        self.df["embedding"] = self.df["embedding"].apply(self._parse_literal)
        self.embeddings_matrix = np.array(self.df["embedding"].tolist(), dtype=float)

    @staticmethod
    def _parse_literal(value):
        """Parse a text cell as a Python literal; pass non-strings through."""
        return ast.literal_eval(value) if isinstance(value, str) else value

    @staticmethod
    def _cosine_scores(query, matrix):
        """Cosine similarity of one query vector against every row of ``matrix``.

        Zero-length vectors are given norm 1 (the convention scikit-learn's
        ``cosine_similarity`` uses), so they score 0 instead of NaN.
        """
        query = np.asarray(query, dtype=float).ravel()
        query_norm = np.linalg.norm(query) or 1.0
        row_norms = np.linalg.norm(matrix, axis=1)
        row_norms[row_norms == 0] = 1.0
        return (matrix @ query) / (row_norms * query_norm)

    def similarity_search(self, query_embedding: np.ndarray, k: int = 1):
        """Return up to ``k`` documents most similar to ``query_embedding``.

        Results are ordered from most to least similar. ``k <= 0`` or an empty
        store yields ``[]``; a ``k`` larger than the store returns every row.
        """
        # Guard first: argsort(...)[-0:] would otherwise select *every* row,
        # and an empty store has no 2-D matrix to score against.
        if k <= 0 or len(self.df) == 0:
            return []
        scores = self._cosine_scores(query_embedding, self.embeddings_matrix)
        top_indices = np.argsort(scores)[-k:][::-1]
        results = []
        for idx in top_indices:
            row = self.df.iloc[idx]
            results.append(
                self.Document(
                    page_content=row["content"],
                    metadata=self._parse_literal(row["metadata"]),
                )
            )
        return results
# System prompt sent ahead of every model call: asks the model to report its
# reasoning and finish with a "FINAL ANSWER: ..." line, and constrains how
# numbers, strings and comma-separated lists in that answer are formatted.
# NOTE(review): "Report your thoughts" and "Your answer should only start with
# 'FINAL ANSWER: '" pull in different directions — confirm which is intended.
system_prompt = """You are a helpful assistant tasked with answering questions using a set of tools. Now, I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, do not use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, do not use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. Your answer should only start with 'FINAL ANSWER: ', then follows with the answer."""
# Every tool the agent may call; build_graph binds this list to the model.
tools = [
    # arithmetic
    multiply, add, subtract, divide, modulus,
    # search / retrieval
    wiki_search, web_search, arxiv_search,
]
from langgraph.graph import StateGraph
from langgraph.graph.message import MessagesState
def _build_default_llm():
    """Create the chat model ``build_graph`` uses when no ``llm`` is supplied.

    A plain text-completion endpoint has no ``bind_tools``; wrapping the
    ``langchain_huggingface`` endpoint in ``ChatHuggingFace`` yields a chat
    model that supports tool binding.
    """
    endpoint = HuggingFaceEndpoint(
        # NOTE(review): this URL must point at a deployed text-generation/chat
        # endpoint that ChatHuggingFace can resolve to a model — confirm.
        endpoint_url="https://api.endpoints.huggingface.co/v1/completions",
        # load_dotenv() runs at import time, so a token from .env is visible
        # here; "inference" is only a placeholder used when it is unset.
        huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN", "inference"),
        max_new_tokens=512,
    )
    return ChatHuggingFace(llm=endpoint)


def _message_text(message) -> str:
    """Return a message's content as plain text.

    LangChain message content is either a string or a list of parts (strings
    or dicts such as ``{"type": "text", "text": ...}``); only text is kept.
    """
    content = message.content
    if isinstance(content, str):
        return content
    texts = []
    for part in content:
        if isinstance(part, str):
            texts.append(part)
        elif isinstance(part, dict) and part.get("type") == "text":
            texts.append(part.get("text", ""))
    return "\n".join(texts)


def build_graph(
    csv_file_path: str = "embedding_database.csv",
    *,
    llm=None,
    embed_query: Optional[Callable[[str], Sequence[float]]] = None,
):
    """Build and compile the agent graph: [retrieve_docs ->] llm <-> tools.

    Args:
        csv_file_path: CSV loaded into ``CSVVectorStore`` for retrieval.
        llm: Chat model that supports ``bind_tools``; defaults to
            ``_build_default_llm()``.
        embed_query: Callable embedding a query string into the same vector
            space as the CSV's ``embedding`` column. Defaults to a module-level
            ``get_query_embedding`` if one exists; if neither is available the
            retrieval step is skipped and a ``RuntimeWarning`` is issued.

    Returns:
        The compiled LangGraph runnable.
    """
    if llm is None:
        llm = _build_default_llm()
    # Tool calls produced by this model are executed by the "tools" node below.
    llm_with_tools = llm.bind_tools(tools)

    if embed_query is None:
        # Honour a module-level get_query_embedding if one is defined; the
        # retrieval step was written against that name.
        embed_query = globals().get("get_query_embedding")

    vector_store = None
    if embed_query is None:
        warnings.warn(
            "build_graph: no query-embedding function available "
            "(pass embed_query=... or define get_query_embedding); "
            "CSV retrieval is disabled.",
            RuntimeWarning,
            stacklevel=2,
        )
    else:
        vector_store = CSVVectorStore(csv_file_path)

    def retrieve_docs(state: MessagesState) -> dict:
        """Add the 2 CSV entries closest to the latest human message as context."""
        human_messages = [m for m in state["messages"] if isinstance(m, HumanMessage)]
        if not human_messages:
            return {"messages": []}
        query = _message_text(human_messages[-1])
        query_embedding = np.asarray(embed_query(query), dtype=float)
        docs = vector_store.similarity_search(query_embedding, k=2)
        if not docs:
            return {"messages": []}
        combined_doc_text = "\n\n---\n\n".join(
            f"<Document metadata={doc.metadata}>\n{doc.page_content}\n</Document>"
            for doc in docs
        )
        # MessagesState's add_messages reducer appends returned messages, so
        # this lands after the question; call_llm folds it into the system prompt.
        return {"messages": [SystemMessage(content=combined_doc_text)]}

    def call_llm(state: MessagesState) -> dict:
        """Invoke the tool-bound model with a single leading system message."""
        system_parts = [system_prompt]
        conversation = []
        for message in state["messages"]:
            if isinstance(message, SystemMessage):
                system_parts.append(_message_text(message))
            else:
                conversation.append(message)
        # Some chat templates/providers reject system messages that are not
        # first, so all system text is merged into one leading message.
        prompt = [SystemMessage(content="\n\n".join(system_parts))] + conversation
        response = llm_with_tools.invoke(prompt)
        # Return only the new message; the add_messages reducer appends it.
        return {"messages": [response]}

    builder = StateGraph(MessagesState)
    builder.add_node("llm", call_llm)
    builder.add_node("tools", ToolNode(tools))
    if vector_store is not None:
        builder.add_node("retrieve_docs", retrieve_docs)
        builder.add_edge(START, "retrieve_docs")
        builder.add_edge("retrieve_docs", "llm")
    else:
        builder.add_edge(START, "llm")
    # tools_condition routes to "tools" when the last AI message requests tool
    # calls and to END otherwise; tool results are fed back to the model.
    builder.add_conditional_edges("llm", tools_condition)
    builder.add_edge("tools", "llm")
    return builder.compile()
# Manual smoke test: run one question through the agent and print every
# message in the resulting conversation.
if __name__ == "__main__":
    demo_question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    # The CSV path must point at an existing embedding database.
    agent = build_graph(csv_file_path="embedding_database.csv")
    final_state = agent.invoke({"messages": [HumanMessage(content=demo_question)]})
    for message in final_state["messages"]:
        message.pretty_print()