|
from ast import main |
|
import os |
|
from typing import TypedDict, List, Dict, Any, Optional |
|
from langgraph.graph import StateGraph, START, END |
|
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage |
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
from langchain_core.rate_limiters import InMemoryRateLimiter |
|
|
|
|
|
GAIA_PROMPT = "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string." |
|
|
|
|
|
|
|
|
|
|
|
class GAIAAgentState(TypedDict):
    """State of the GAIA agent.

    Schema of the LangGraph state shared between graph nodes; GraphManager
    builds its ``StateGraph`` from this class. Field order and annotations
    are read at runtime, so keep them stable.
    """

    # Identifier of the task being answered (presumably the GAIA task_id —
    # TODO confirm against the caller that fills it in).
    task_id: str

    # The question text the agent must answer.
    question: str

    # Reference to a file attached to the task, or None when there is none.
    # NOTE(review): exact meaning (file name vs. download id) is not visible
    # here — confirm.
    file_id: Optional[str]

    # The agent's answer; presumably None until a node produces it.
    answer: Optional[str]

    # The model's intermediate reasoning, if recorded; presumably None
    # until a node produces it.
    thought: Optional[str]
|
|
|
|
|
|
|
class BasicAgent:
    """Single-shot GAIA question answerer backed by a Gemini chat model.

    Each call sends ``GAIA_PROMPT`` as the system message and the question as
    the human message. It returns only the text after the last
    ``FINAL ANSWER:`` marker, which is the template ``GAIA_PROMPT`` asks the
    model to finish with, rather than the model's whole reasoning.
    """

    # Marker that GAIA_PROMPT instructs the model to put before its answer.
    FINAL_ANSWER_MARKER = "FINAL ANSWER:"

    def __init__(
        self,
        model_name: str = "gemini-2.0-flash",
        requests_per_second: float = 0.2,
    ):
        """Create the rate-limited Gemini chat client.

        Args:
            model_name: Gemini model passed to ``ChatGoogleGenerativeAI``.
            requests_per_second: Client-side request budget for the
                ``InMemoryRateLimiter``. The default of 0.2 is roughly one
                request every five seconds.

        Raises:
            KeyError: If the ``GEMINI_API_KEY`` environment variable is unset
                or empty.
        """
        # Validate the key before building anything, so a misconfigured
        # environment fails fast with an actionable message.
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise KeyError(
                "GEMINI_API_KEY environment variable is not set; "
                "it is required to call the Gemini API."
            )

        self.rate_limiter = InMemoryRateLimiter(
            requests_per_second=requests_per_second
        )
        self.model = ChatGoogleGenerativeAI(
            model=model_name,
            temperature=0,  # make answers as repeatable as possible
            max_tokens=None,
            timeout=None,
            max_retries=2,
            google_api_key=api_key,
            rate_limiter=self.rate_limiter,
        )
        print("BasicAgent initialized.")

    def __call__(self, question: str) -> str:
        """Ask the model *question* and return its final answer.

        Args:
            question: The question text, sent verbatim as the human message.

        Returns:
            The stripped text after the last ``FINAL ANSWER:`` marker in the
            model's reply. If the model omitted the marker, the full reply
            text is returned instead.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        messages = [
            ("system", GAIA_PROMPT),
            ("human", question),
        ]

        ai_msg = self.model.invoke(messages)

        response_text = self._content_to_text(ai_msg.content)
        print(f"Agent raw response: {response_text}")
        answer = self._extract_final_answer(response_text)
        print(f"Agent returning response: {answer}")
        return answer

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a chat message's ``content`` into plain text.

        LangChain types message content as either a string or a list of
        parts, where each part is a plain string or a dict such as
        ``{"type": "text", "text": ...}``. Calling ``str()`` on the list
        would yield its Python repr, so text parts are joined instead and
        non-text parts are skipped. Any other value falls back to ``str()``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and part.get("type", "text") == "text":
                    pieces.append(str(part.get("text", "")))
            return "".join(pieces)
        return str(content)

    @classmethod
    def _extract_final_answer(cls, text: str) -> str:
        """Return the answer after the last ``FINAL ANSWER:`` marker.

        The last occurrence is used because the model may mention the
        template earlier while reasoning. If the marker is absent, *text* is
        returned unchanged.
        """
        _, marker, tail = text.rpartition(cls.FINAL_ANSWER_MARKER)
        if not marker:
            return text
        # Drop surrounding whitespace and Markdown bold the model may wrap
        # around the answer, e.g. "**FINAL ANSWER:** 42".
        return tail.strip().strip("*").strip()
|
|
|
|
|
class GraphManager:
    """Owns the LangGraph ``StateGraph`` for the GAIA agent workflow."""

    # Name under which the entry node is registered in the graph.
    ENTRY_NODE = "read_question_and_define_gaia_state"

    def __init__(self):
        # An empty graph over the GAIAAgentState schema; build_graph() wires it.
        self.graph = StateGraph(GAIAAgentState)
        print("GraphManager initialized.")

    def read_question_and_define_gaia_state(
        self, state: GAIAAgentState
    ) -> Dict[str, Any]:
        """Entry node of the workflow (currently a placeholder).

        Args:
            state: The current graph state.

        Returns:
            A dict of state updates. LangGraph nodes return partial updates,
            and the empty dict returned here leaves the state unchanged.

        TODO: read the question and populate/normalize the GAIA state here.
        """
        return {}

    def build_graph(self) -> StateGraph:
        """Wire nodes and edges into ``self.graph`` and return it.

        Resulting topology: START -> read_question_and_define_gaia_state -> END.
        Calling this more than once is safe. Construction only happens on the
        first call, because re-adding a node with an existing name is rejected
        by ``StateGraph.add_node``.

        Returns:
            The (uncompiled) ``StateGraph``; compile it before invoking.
        """
        if self.ENTRY_NODE not in self.graph.nodes:
            self.graph.add_node(
                self.ENTRY_NODE,
                self.read_question_and_define_gaia_state,
            )
            self.graph.add_edge(START, self.ENTRY_NODE)
            # Explicit exit edge; without it the entry node has no successor.
            self.graph.add_edge(self.ENTRY_NODE, END)
        return self.graph
|
|