|
import os |
|
import re |
|
from datetime import datetime, timedelta |
|
from typing import TypedDict, Annotated |
|
import sympy as sp |
|
from sympy import * |
|
import math |
|
from langchain_openai import ChatOpenAI |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_core.messages import HumanMessage, SystemMessage |
|
from langgraph.graph import StateGraph, MessagesState, START, END |
|
from langgraph.prebuilt import ToolNode |
|
from langgraph.checkpoint.memory import MemorySaver |
|
import json |
|
|
|
|
|
from dotenv import load_dotenv |
|
load_dotenv() |
|
|
|
def read_system_prompt():
    """Return the system prompt used to prime the agent.

    The prompt is read from ``system_prompt.txt`` (resolved against the
    current working directory) and stripped of surrounding whitespace.
    When that file does not exist, a built-in default prompt describing the
    ``FINAL ANSWER:`` response template is returned instead.

    Returns:
        str: The prompt text.
    """
    try:
        # Decode explicitly as UTF-8: without ``encoding`` Python falls back
        # to the locale's preferred encoding, so the same prompt file could
        # be read differently (or fail to decode) on another machine.
        with open('system_prompt.txt', 'r', encoding='utf-8') as f:
            return f.read().strip()
    except FileNotFoundError:
        # The continuation lines of this literal start at column 0 on purpose
        # so that no indentation leaks into the prompt text.
        return """You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer."""
|
|
|
def math_calculator(expression: str) -> str:
    """Evaluate a mathematical expression and return the result as text.

    The expression is first normalised (``^`` becomes ``**`` and a
    standalone ``ln`` becomes ``log``) and handed to SymPy. If SymPy can
    reduce it to a real number, the float value is returned; otherwise the
    simplified symbolic form is returned. When SymPy cannot handle the
    input at all, a restricted ``eval`` using the ``math`` module is tried
    as a last resort.

    Args:
        expression: The expression to evaluate, e.g. ``"2^3 + sqrt(16)"``.

    Returns:
        str: The numeric result (formatted via ``str(float)``), the
        simplified symbolic expression, or a message starting with
        ``"Error calculating"`` when every strategy fails. This function
        does not raise.
    """
    try:
        expression = expression.strip()
        expression = expression.replace('^', '**')
        # Only rewrite a standalone ``ln``; a plain substring replace would
        # also mangle any identifier that merely contains those letters.
        expression = re.sub(r'\bln\b', 'log', expression)

        try:
            # SECURITY: sympify() evaluates its input internally and the
            # fallback below calls eval(); both are unsafe on untrusted text.
            # Callers should only pass expressions they are willing to execute.
            result = sp.sympify(expression)
            simplified = sp.simplify(result)
            try:
                return str(float(simplified.evalf()))
            except Exception:
                # Not reducible to a real float (free symbols, complex
                # values, ...): report the symbolic form instead.
                return str(simplified)
        except Exception:
            # Prefix bare function names with ``math.``. The look-behind and
            # word boundary keep names such as ``asin`` or an existing
            # ``math.sin`` from being corrupted.
            safe_expression = re.sub(
                r'(?<![\w.])(sin|cos|tan|sqrt|log|exp)\b',
                r'math.\1',
                expression,
            )
            # ``abs`` is provided directly because ``math.abs`` does not
            # exist and builtins are removed from the evaluation namespace.
            result = eval(safe_expression, {"__builtins__": {}}, {
                "math": math,
                "pi": math.pi,
                "e": math.e,
                "abs": abs,
            })
            return str(result)

    except Exception as e:
        return f"Error calculating '{expression}': {str(e)}"
|
|
|
def date_time_processor(query: str) -> str:
    """Answer simple date/time questions relative to the local clock.

    The query is matched case-insensitively against a few phrasings, most
    specific first:

    1. ``"<N> day(s) ago"``                  -> that date (``YYYY-MM-DD``)
    2. ``"<N> day(s) from now"`` / ``later`` -> that date (``YYYY-MM-DD``)
    3. ``"day of week"`` / ``"what day"``    -> weekday name
    4. ``current`` / ``today`` / ``now`` combined with ``date``, ``time``,
       ``year`` or ``month``                 -> that component; otherwise
       the full ``YYYY-MM-DD HH:MM:SS`` timestamp
    5. anything else                         -> ``"Current date and time: ..."``

    Relative-day arithmetic is checked first because such queries also
    contain the word "now" (and often "date"), which would otherwise be
    answered with today's date.

    Args:
        query: Natural-language question about dates or times.

    Returns:
        str: The formatted answer, or a message starting with
        ``"Error processing date/time query"`` if processing fails. This
        function does not raise.
    """
    try:
        current_time = datetime.now()
        query_lower = query.lower()

        ago_match = re.search(r'(\d+)\s+days?\s+ago', query_lower)
        if ago_match:
            past_date = current_time - timedelta(days=int(ago_match.group(1)))
            return past_date.strftime('%Y-%m-%d')

        ahead_match = re.search(r'(\d+)\s+days?\s+(?:from now|later)', query_lower)
        if ahead_match:
            future_date = current_time + timedelta(days=int(ahead_match.group(1)))
            return future_date.strftime('%Y-%m-%d')

        if 'day of week' in query_lower or 'what day' in query_lower:
            return current_time.strftime('%A')

        has_current = 'current' in query_lower
        if has_current or 'today' in query_lower or 'now' in query_lower:
            if 'date' in query_lower:
                return current_time.strftime('%Y-%m-%d')
            if 'time' in query_lower:
                return current_time.strftime('%H:%M:%S')
            # The year/month answers require the word "current"; they live
            # inside this branch because any query containing "current"
            # lands here, so separate checks afterwards could never run.
            if has_current and 'year' in query_lower:
                return str(current_time.year)
            if has_current and 'month' in query_lower:
                return current_time.strftime('%B')
            return current_time.strftime('%Y-%m-%d %H:%M:%S')

        return f"Current date and time: {current_time.strftime('%Y-%m-%d %H:%M:%S')}"

    except Exception as e:
        return f"Error processing date/time query: {str(e)}"
|
|
|
|
|
class AgentState(TypedDict):
    # State schema passed to StateGraph: a single "messages" key holding the
    # conversation as a list.
    # NOTE(review): the Annotated metadata is a plain description string, not
    # a reducer function, so presumably each node's returned "messages" value
    # replaces the stored list rather than being appended — confirm against
    # the LangGraph version in use.
    messages: Annotated[list, "The messages in the conversation"]
|
|
|
class GAIAAgent:
    """Question-answering agent built on a two-node LangGraph workflow.

    The ``agent`` node calls an OpenAI chat model that has the Tavily search
    tool bound to it; the ``tools`` node executes any tool calls the model
    requests and loops back to the agent. Before the first model call the
    user's question may be enriched with the output of ``math_calculator``
    or ``date_time_processor`` when it looks like a math or date/time
    question.

    Instances are callable: ``agent(question)`` returns the text following
    the last ``FINAL ANSWER:`` marker in the model's reply.
    """

    # Whole-word math vocabulary. Substring matching would also fire on words
    # such as "important" (contains "tan") or "summary" (contains "sum").
    _MATH_KEYWORDS = re.compile(
        r'\b(?:calculat\w*|comput\w*|solv\w*|equations?|integrals?|'
        r'derivatives?|sums?|sqrt|log|sin|cos|tan|exp)\b'
    )
    # An operator only counts when it sits between operands. "-", "/" and "="
    # additionally require numeric operands because they are common in prose
    # (hyphenated words, "and/or", URLs).
    _MATH_OPERATION = re.compile(
        r'[\w)]\s*[+*^]\s*[\w(]|[\d)]\s*[-/=]\s*[\d(]'
    )
    # Whole-word date/time vocabulary (avoids hits such as "Chicago" -> "ago").
    _DATETIME_KEYWORDS = re.compile(
        r'\b(?:dates?|times?|days?|months?|years?|today|yesterday|tomorrow|'
        r'current|now|ago|later|when)\b'
    )

    def __init__(self):
        """Create the model, the search tool and the compiled graph.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` or ``TAVILY_API_KEY`` is not
                set in the environment.
        """
        openai_key = os.getenv("OPENAI_API_KEY")
        tavily_key = os.getenv("TAVILY_API_KEY")

        if not openai_key:
            raise ValueError("OPENAI_API_KEY environment variable is required")
        if not tavily_key:
            raise ValueError("TAVILY_API_KEY environment variable is required")

        print("✅ API keys found - initializing agent...")

        self.llm = ChatOpenAI(
            model="gpt-4o-mini",
            temperature=0,
            openai_api_key=openai_key
        )

        self.search_tool = TavilySearchResults(
            max_results=5,
            tavily_api_key=tavily_key
        )

        self.tools = [self.search_tool]
        self.llm_with_tools = self.llm.bind_tools(self.tools)

        # Load the prompt before building the graph so every attribute the
        # graph nodes read already exists once the graph is compiled.
        self.system_prompt = read_system_prompt()
        self.graph = self._build_graph()

    def _build_graph(self):
        """Build and compile the LangGraph workflow.

        Returns:
            The compiled graph (``START -> agent -> [tools -> agent]* -> END``)
            with an in-memory checkpointer attached.
        """
        # Built once and reused by every tool step.
        tool_executor = ToolNode(self.tools)

        def agent_node(state: AgentState):
            """Main agent reasoning node: call the model on the conversation."""
            # Work on a copy so the state's own list is never mutated in place.
            messages = list(state["messages"])

            if not any(isinstance(msg, SystemMessage) for msg in messages):
                messages.insert(0, SystemMessage(content=self.system_prompt))

            # Enrich the question only while it is still the newest message.
            # On later passes the newest message is a tool result, which must
            # not be overwritten.
            last_message = messages[-1]
            if isinstance(last_message, HumanMessage):
                enhanced_msg = self._preprocess_question(last_message.content)
                if enhanced_msg is not None:
                    messages[-1] = HumanMessage(content=enhanced_msg)

            response = self.llm_with_tools.invoke(messages)
            return {"messages": messages + [response]}

        def tool_node(state: AgentState):
            """Tool execution node: run requested tools, keep the history."""
            messages = list(state["messages"])
            result = tool_executor.invoke(state)
            tool_messages = result["messages"] if isinstance(result, dict) else result
            # AgentState declares no reducer for "messages", so this node
            # returns the full history plus the tool output. Returning the
            # tool output alone would leave the next model call with tool
            # messages that have no preceding tool-call message.
            return {"messages": messages + list(tool_messages)}

        def should_continue(state: AgentState):
            """Route to the tools node when the model requested tool calls."""
            last_message = state["messages"][-1]
            if getattr(last_message, 'tool_calls', None):
                return "tools"
            # With or without a FINAL ANSWER marker, a reply that requests no
            # tools ends the run.
            return "end"

        workflow = StateGraph(AgentState)

        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", tool_node)

        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue, {
            "tools": "tools",
            "end": END
        })
        workflow.add_edge("tools", "agent")

        memory = MemorySaver()
        return workflow.compile(checkpointer=memory)

    def _preprocess_question(self, question):
        """Return the question enriched with helper output, or ``None``.

        Math detection takes precedence over date/time detection. Content
        that is not a plain string is left untouched.
        """
        if not question or not isinstance(question, str):
            return None

        if self._is_math_problem(question):
            math_result = math_calculator(question)
            return f"Math calculation result: {math_result}\n\nOriginal question: {question}\n\nProvide your final answer based on this calculation."

        if self._is_datetime_problem(question):
            datetime_result = date_time_processor(question)
            return f"Date/time processing result: {datetime_result}\n\nOriginal question: {question}\n\nProvide your final answer based on this information."

        return None

    def _is_math_problem(self, text: str) -> bool:
        """Check if the text contains mathematical expressions.

        True when the text has a math keyword as a whole word (matched
        case-insensitively) or an arithmetic operator placed between
        operands.
        """
        return bool(
            self._MATH_KEYWORDS.search(text.lower())
            or self._MATH_OPERATION.search(text)
        )

    def _is_datetime_problem(self, text: str) -> bool:
        """Check if the text contains date/time related queries.

        True when a date/time keyword appears as a whole word, matched
        case-insensitively.
        """
        return self._DATETIME_KEYWORDS.search(text.lower()) is not None

    def __call__(self, question: str) -> str:
        """Process a question and return the answer.

        Args:
            question: The question to answer.

        Returns:
            str: The extracted final answer, or ``"Error: <message>"`` if
            anything fails while running the graph.
        """
        try:
            print(f"Processing question: {question[:100]}...")

            initial_state = {
                "messages": [HumanMessage(content=question)]
            }

            config = {"configurable": {"thread_id": "gaia_thread"}}
            final_state = self.graph.invoke(initial_state, config)

            last_message = final_state["messages"][-1]
            response_content = last_message.content if hasattr(last_message, 'content') else str(last_message)

            final_answer = self._extract_final_answer(response_content)

            print(f"Final answer: {final_answer}")
            return final_answer

        except Exception as e:
            print(f"Error processing question: {e}")
            return f"Error: {str(e)}"

    def _extract_final_answer(self, response: str) -> str:
        """Extract the final answer from the response.

        Returns the first line of text after the last ``FINAL ANSWER:``
        marker, stripped. Without a marker the whole stripped response is
        returned.
        """
        if not isinstance(response, str):
            # Not a plain string: fall back to its string form so the marker
            # search below cannot fail.
            response = str(response)

        _, marker, tail = response.rpartition("FINAL ANSWER:")
        if marker:
            return tail.strip().split('\n')[0].strip()

        return response.strip()
|
|
|
|
|
def create_agent():
    """Build a ready-to-use GAIA agent.

    Returns:
        GAIAAgent: A newly constructed agent. Construction raises
        ``ValueError`` when a required API key is missing from the
        environment.
    """
    agent = GAIAAgent()
    return agent