|
"""LangGraph Agent with CSV-based Vector Store""" |
|
import os |
|
import ast |
|
import pandas as pd |
|
import numpy as np |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
from dotenv import load_dotenv |
|
from langgraph.graph import START, StateGraph, MessagesState |
|
from langgraph.prebuilt import tools_condition, ToolNode |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
from langchain_groq import ChatGroq |
|
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader |
|
from langchain_core.messages import SystemMessage, HumanMessage |
|
from langchain_core.tools import tool |
|
|
|
load_dotenv() |
|
|
|
|
|
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    # Docstring left verbatim: @tool is expected to surface it to the model
    # as the tool description (confirm against the langchain_core docs).
    product = a * b
    return product
|
|
|
@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    # Docstring left verbatim: @tool is expected to surface it to the model
    # as the tool description (confirm against the langchain_core docs).
    total = a + b
    return total
|
|
|
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers."""
    # Docstring left verbatim: @tool is expected to surface it to the model
    # as the tool description (confirm against the langchain_core docs).
    # Operand order matters: the result is `a` minus `b`.
    difference = a - b
    return difference
|
|
|
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers."""
    # Docstring left verbatim: @tool is expected to surface it to the model
    # as the tool description (confirm against the langchain_core docs).
    # True division is used, so integer operands still yield a float.
    if b != 0:
        return a / b
    # A zero divisor is reported as ValueError with a readable message
    # instead of letting ZeroDivisionError escape.
    raise ValueError("Cannot divide by zero.")
|
|
|
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers."""
    # Docstring left verbatim: @tool is expected to surface it to the model
    # as the tool description (confirm against the langchain_core docs).
    #
    # Returns `a % b`. Raises ValueError when `b` is zero, mirroring how
    # divide() reports a zero divisor, so both arithmetic tools fail the same
    # way instead of this one leaking a bare ZeroDivisionError.
    if b == 0:
        raise ValueError("Cannot take modulus by zero.")
    return a % b
|
|
|
|
|
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results."""
    # Docstring left verbatim: @tool is expected to surface it to the model
    # as the tool description (confirm against the langchain_core docs).
    #
    # Loads at most two pages, renders each as a <Document ...> block carrying
    # its source and page metadata, and joins the blocks with a "---" rule.
    pages = WikipediaLoader(query=query, load_max_docs=2).load()
    rendered = []
    for page in pages:
        rendered.append(
            f'<Document source="{page.metadata["source"]}" page="{page.metadata.get("page", "")}"/>\n{page.page_content}\n</Document>'
        )
    separator = "\n\n---\n\n"
    return separator.join(rendered)
|
|
|
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results."""
    # Docstring left verbatim: @tool is expected to surface it to the model
    # as the tool description (confirm against the langchain_core docs).
    #
    # Returns the hits rendered as <Document ...> blocks (url, title, content)
    # joined by a "---" rule.
    #
    # invoke() takes the tool input as its first positional argument; the
    # earlier `invoke(query=query)` never supplied it and failed with a
    # TypeError before any search was made.
    results = TavilySearchResults(max_results=3).invoke({"query": query})

    # NOTE(review): on an API failure the Tavily tool appears to hand back an
    # error string instead of a list of hit dicts; pass that text through
    # rather than crashing on `.get` — confirm against the tool's docs.
    if isinstance(results, str):
        return results

    return "\n\n---\n\n".join(
        f'<Document source="{hit.get("url", "")}" title="{hit.get("title", "")}"/>\n{hit.get("content", "")}\n</Document>'
        for hit in results
    )
|
|
|
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 results."""
    # Docstring left verbatim: @tool is expected to surface it to the model
    # as the tool description (confirm against the langchain_core docs).
    #
    # Returns up to three papers rendered as <Document ...> blocks, each
    # truncated to its first 1000 characters, joined by a "---" rule.
    papers = ArxivLoader(query=query, load_max_docs=3).load()

    blocks = []
    for paper in papers:
        meta = paper.metadata
        # NOTE(review): ArxivLoader's documented metadata keys are
        # Published/Title/Authors/Summary (plus entry_id etc. when all
        # metadata is requested) — there is no "source", so indexing it
        # directly raised KeyError. Prefer "source" if present, then fall
        # back to an identifier that is. Confirm against the loader docs.
        source = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
        page = meta.get("page", "")
        excerpt = paper.page_content[:1000]
        blocks.append(
            f'<Document source="{source}" page="{page}"/>\n{excerpt}\n</Document>'
        )

    return "\n\n---\n\n".join(blocks)
|
|
|
|
|
class CSVVectorStore:
    """In-memory vector store backed by a CSV file of stored embeddings.

    The CSV is read with pandas and is expected to provide the columns this
    class reads: ``embedding`` (a Python-literal list per row), ``content``
    and ``metadata``.
    """

    class Document:
        """Minimal search hit exposing ``page_content`` and ``metadata``."""

        def __init__(self, page_content, metadata):
            self.page_content = page_content
            self.metadata = metadata

    def __init__(self, csv_file_path: str):
        """Initialize the CSV vector store.

        Args:
            csv_file_path: Path of the CSV file to load.
        """
        self.df = pd.read_csv(csv_file_path)

        # Embeddings arrive as text (e.g. "[0.1, 0.2]"); literal_eval turns
        # them back into lists without evaluating arbitrary code.
        self.df['embedding'] = self.df['embedding'].apply(ast.literal_eval)
        self.embeddings_matrix = np.array(self.df['embedding'].tolist())

    def similarity_search(self, query_embedding: np.ndarray, k: int = 1):
        """Find most similar documents to the query embedding.

        Args:
            query_embedding: One embedding vector (1-D, or a single-row 2-D
                array / nested list) with the same width as the stored ones.
            k: Maximum number of documents to return.

        Returns:
            Up to ``k`` ``Document`` objects, best cosine similarity first.
            An empty list when ``k`` is not positive or the store is empty.
        """
        # `scores[-k:]` with k == 0 is the full slice, so a request for zero
        # hits used to return every row; bail out before ranking instead.
        if k <= 0 or len(self.embeddings_matrix) == 0:
            return []

        query = np.asarray(query_embedding, dtype=float).reshape(1, -1)
        similarities = cosine_similarity(query, self.embeddings_matrix)[0]

        # Ascending argsort reversed -> best match first; keep the first k.
        top_indices = np.argsort(similarities)[::-1][:k]

        results = []
        for idx in top_indices:
            row = self.df.iloc[idx]
            metadata = row['metadata']
            # Metadata is stored as a Python-literal string in the CSV; leave
            # already-parsed (non-string) values untouched.
            if isinstance(metadata, str):
                metadata = ast.literal_eval(metadata)
            results.append(
                self.Document(page_content=row['content'], metadata=metadata)
            )

        return results
|
|
|
|
|
# System prompt sent ahead of every conversation. It is one single-line
# string; it is split per sentence here purely for readability, and adjacent
# literals are concatenated by the parser (each piece keeps its trailing space).
system_prompt = (
    "You are a helpful assistant tasked with answering questions using a set of tools. "
    "Now, I will ask you a question. "
    "Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. "
    "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. "
    "If you are asked for a number, do not use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
    "If you are asked for a string, do not use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
    "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. "
    "Your answer should only start with 'FINAL ANSWER: ', then follows with the answer."
)
|
|
|
|
|
# Every tool offered to the model, in a fixed order: the arithmetic helpers
# first, then the three search tools.
tools = [multiply, add, subtract, divide, modulus]
tools += [wiki_search, web_search, arxiv_search]
|
|
|
from langgraph.graph import StateGraph |
|
from langgraph.graph.message import MessagesState |
|
|
|
def build_graph(csv_file_path: str = "embedding_database.csv", *, llm=None, embed_query=None):
    """Build the LangGraph with CSV-based vector store and tools.

    The compiled graph first looks up the stored documents closest to the
    latest human message, then alternates between the model and the tools
    until the model replies without requesting a tool call.

    Args:
        csv_file_path: CSV file handed to ``CSVVectorStore``.
        llm: Optional chat model exposing ``bind_tools``. When omitted, a
            ``ChatHuggingFace`` model wrapping the Hugging Face endpoint
            configured below is created.
        embed_query: Optional callable mapping the query text to an embedding
            vector. When omitted, a module-level ``get_query_embedding``
            function is looked up and used.

    Returns:
        The compiled graph.

    Raises:
        RuntimeError: If ``embed_query`` is omitted and this module defines
            no ``get_query_embedding``.
    """
    # Resolve the embedder up front so a missing one fails here with a clear
    # message instead of as a NameError in the middle of a graph run.
    if embed_query is None:
        embed_query = globals().get("get_query_embedding")
        if embed_query is None:
            raise RuntimeError(
                "No query embedder available: pass embed_query=... or define "
                "get_query_embedding() in this module."
            )

    vector_store = CSVVectorStore(csv_file_path)
    sys_msg = SystemMessage(content=system_prompt)

    if llm is None:
        # Use the module-level langchain_huggingface classes. A function-local
        # import of the plain-LLM HuggingFaceEndpoint used to shadow them, and
        # that class has no bind_tools(); the chat wrapper does.
        # NOTE(review): the endpoint URL and the fallback token look like
        # placeholders — confirm the real deployment values.
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api.endpoints.huggingface.co/v1/completions",
                huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN", "inference"),
                model_kwargs={"max_tokens": 512},
            )
        )

    llm_with_tools = llm.bind_tools(tools)

    def retrieve_docs(state: MessagesState) -> dict:
        """Attach the stored documents most similar to the latest question."""
        question = next(
            (msg for msg in reversed(state["messages"]) if isinstance(msg, HumanMessage)),
            None,
        )
        # Nothing to look up without a human message.
        if question is None:
            return {"messages": []}

        docs = vector_store.similarity_search(embed_query(question.content), k=2)
        if not docs:
            return {"messages": []}

        content_blocks = [
            f"<Document metadata={doc.metadata}>\n{doc.page_content}\n</Document>"
            for doc in docs
        ]
        combined_doc_text = "\n\n---\n\n".join(content_blocks)

        # The messages reducer merges a node's output into the existing
        # history, so only the new message is returned; ordering for the
        # model is handled in call_llm.
        return {"messages": [SystemMessage(content=combined_doc_text)]}

    def call_llm(state: MessagesState) -> dict:
        """Ask the model for its next message (an answer or tool calls)."""
        history = state["messages"]
        # Retrieved context is appended after the question in the state; move
        # system messages to the front so it precedes the dialogue as intended.
        context = [msg for msg in history if isinstance(msg, SystemMessage)]
        dialogue = [msg for msg in history if not isinstance(msg, SystemMessage)]
        response = llm_with_tools.invoke([sys_msg, *context, *dialogue])
        return {"messages": [response]}

    graph = StateGraph(MessagesState)
    graph.add_node("retrieve_docs", retrieve_docs)
    graph.add_node("llm", call_llm)
    # Without this node the bound tools were advertised to the model but
    # never executed. tools_condition routes to the node named "tools".
    graph.add_node("tools", ToolNode(tools))

    graph.add_edge(START, "retrieve_docs")
    graph.add_edge("retrieve_docs", "llm")
    # Run the tools when the model asked for them, otherwise finish.
    graph.add_conditional_edges("llm", tools_condition)
    graph.add_edge("tools", "llm")

    return graph.compile()
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: build the agent, ask one sample question and print
    # every message of the resulting conversation.
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"

    agent = build_graph(csv_file_path="embedding_database.csv")

    final_state = agent.invoke({"messages": [HumanMessage(content=question)]})

    for message in final_state["messages"]:
        message.pretty_print()