Spaces:
Sleeping
Sleeping
import os | |
import json | |
import re | |
import base64 | |
import streamlit as st | |
from io import BytesIO | |
from langchain_core.utils.function_calling import convert_to_openai_function | |
from langchain_core.messages import ( | |
AIMessage, | |
BaseMessage, | |
ChatMessage, | |
FunctionMessage, | |
HumanMessage, | |
) | |
from langchain.tools.render import format_tool_to_openai_function | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langgraph.graph import END, StateGraph | |
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation | |
from langchain_core.tools import tool | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
from langchain_experimental.utilities import PythonREPL | |
from langchain_openai import ChatOpenAI | |
from typing import Annotated, Sequence | |
from typing_extensions import TypedDict | |
import operator | |
import functools | |
import matplotlib.pyplot as plt | |
# Read the API keys for Tavily (web search) and OpenAI (chat model) from the
# environment, e.g. secrets configured for the deployment.
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Halt the Streamlit script with an error banner unless both keys are set and
# non-empty; nothing below can work without them.
if not (TAVILY_API_KEY and OPENAI_API_KEY):
    st.error("API keys are missing. Please set TAVILY_API_KEY and OPENAI_API_KEY as secrets.")
    st.stop()
class AgentState(TypedDict):
    """Shared graph state passed between the agent nodes and the tool node."""

    # Conversation so far. The operator.add reducer makes LangGraph concatenate
    # each node's returned messages onto this list instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the agent node that produced the latest update; the "call_tool"
    # node routes its result back to this agent.
    sender: str
# Tool instances shared by the agents.
# Web search tool, limited to at most five results per query.
tavily_tool = TavilySearchResults(max_results=5)
# NOTE(review): `repl` is not referenced by any visible code; python_repl runs
# snippets via exec() directly. Kept in case other code imports it.
repl = PythonREPL()
@tool
def python_repl(code: Annotated[str, "The python code to execute to generate your chart."]):
    """Executes Python code to generate a chart and returns the chart as a base64-encoded image."""
    # Developer notes (kept out of the docstring, which @tool sends to the model
    # as the tool description):
    # * @tool turns this function into a LangChain tool with a `.name`, which
    #   ToolExecutor and create_agent both rely on; a bare function has none.
    # * Returns {"status": "success", "image": <base64 PNG>} on success, or
    #   {"status": "failed", "error": <message>}; tool_node checks for exactly
    #   this success shape.
    # * SECURITY: `code` is model-generated and can be steered by web-search
    #   results or the user's query; exec() runs it with this process's full
    #   privileges. Sandbox it before exposing the app to untrusted users.
    namespace = {"plt": plt}
    try:
        # Use ONE dict for globals and locals: with two separate dicts exec runs
        # the snippet as if it were a class body, so functions and comprehension
        # bodies defined in it cannot see the snippet's own top-level names.
        exec(code, namespace)
        if not plt.get_fignums():
            # Without this check savefig() would emit a blank image as "success".
            return {"status": "failed", "error": "The code ran but did not create a matplotlib figure."}
        buf = BytesIO()
        plt.savefig(buf, format="png")
        encoded_image = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"status": "success", "image": encoded_image}
    except Exception as e:
        return {"status": "failed", "error": repr(e)}
    finally:
        # Discard every open figure, even after a failure, so a half-drawn
        # chart never bleeds into the next call.
        plt.close("all")
# Every tool any agent may request: web search plus the chart-drawing tool.
tools = [tavily_tool, python_repl]
# Single executor used by tool_node; it receives ToolInvocation requests that
# identify the tool to run by its name.
tool_executor = ToolExecutor(tools)
# Define tool node
def tool_node(state):
    """Execute the tool requested by the latest agent message.

    The last message in ``state["messages"]`` must carry an OpenAI legacy
    ``function_call`` (``name`` plus JSON-encoded ``arguments``) in its
    ``additional_kwargs``; ``router`` only sends the graph here in that case.

    Returns a state update holding a single new message:

    * a plain dict ``{"role": "assistant", "content": ..., "image": <base64>}``
      when the tool reports ``{"status": "success", "image": ...}``, so the UI
      can render the chart;
    * otherwise a ``FunctionMessage`` with the stringified tool response, or an
      error description when the arguments are not valid JSON.
    """
    last_message = state["messages"][-1]
    function_call = last_message.additional_kwargs["function_call"]
    tool_name = function_call["name"]

    try:
        tool_input = json.loads(function_call["arguments"])
    except json.JSONDecodeError as e:
        # Models occasionally emit malformed arguments; report that back to the
        # agent so it can retry instead of aborting the whole workflow.
        error_message = FunctionMessage(
            content=f"{tool_name} error: arguments are not valid JSON ({e}). Retry with valid JSON arguments.",
            name=tool_name,
        )
        return {"messages": [error_message]}

    # Single-input tools are described to the model with one parameter named
    # "__arg1"; unwrap it so the tool receives the bare value.
    if isinstance(tool_input, dict) and len(tool_input) == 1 and "__arg1" in tool_input:
        tool_input = tool_input["__arg1"]

    action = ToolInvocation(tool=tool_name, tool_input=tool_input)
    response = tool_executor.invoke(action)

    if isinstance(response, dict) and response.get("status") == "success" and "image" in response:
        return {
            "messages": [
                {
                    "role": "assistant",
                    "content": "Image generated successfully.",
                    "image": response["image"],
                }
            ]
        }

    function_message = FunctionMessage(
        content=f"{tool_name} response: {str(response)}", name=action.tool
    )
    return {"messages": [function_message]}
# Define router
def router(state):
    """Choose the edge to follow after an agent node runs.

    Returns "call_tool" when the newest message requests a function call,
    "end" when its text contains the "FINAL ANSWER" marker, and "continue"
    (hand off to the other agent) otherwise. A function call takes priority
    over the marker.
    """
    latest = state["messages"][-1]
    wants_tool = "function_call" in latest.additional_kwargs
    if wants_tool:
        return "call_tool"
    return "end" if "FINAL ANSWER" in latest.content else "continue"
# Define agent creation function
def create_agent(llm, tools, system_message: str):
    """Build the prompt -> LLM runnable for one agent in the team.

    The system prompt explains the collaboration protocol (including the
    "FINAL ANSWER" stop marker that ``router`` looks for), lists this agent's
    tool names and appends ``system_message``. The LLM is bound to the tools'
    OpenAI function schemas so it can answer with ``function_call`` requests.

    Args:
        llm: chat model exposing ``bind_functions`` (``ChatOpenAI`` in this file).
        tools: tools available to this agent; each must have a ``.name``.
        system_message: agent-specific instructions appended to the prompt.

    Returns:
        The runnable ``prompt | llm.bind_functions(<schemas>)``.
    """
    function_schemas = [convert_to_openai_function(t) for t in tools]
    tool_names = ", ".join(t.name for t in tools)

    system_prompt = (
        "You are a helpful AI assistant, collaborating with other assistants."
        " Use the provided tools to progress towards answering the question."
        " If you are unable to fully answer, that's OK, another assistant with different tools "
        " will help where you left off. Execute what you can to make progress."
        " If you or any of the other assistants have the final answer or deliverable,"
        " prefix your response with FINAL ANSWER so the team knows to stop."
        " You have access to the following tools: {tool_names}.\n{system_message}"
    )
    template = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    # Fill in the fixed parts now; only "messages" is supplied at invoke time.
    template = template.partial(system_message=system_message, tool_names=tool_names)
    return template | llm.bind_functions(function_schemas)
# Define agent node
def agent_node(state, agent, name):
    """Run one agent on the shared state and package its reply as a state update.

    Replies that are not a ``FunctionMessage`` are re-created as a
    ``HumanMessage`` that carries over every field except ``type`` and
    ``name``. The message is labelled with the agent's name, with characters
    outside ``[a-zA-Z0-9_-]`` replaced by ``_`` to satisfy OpenAI's naming
    rules. ``sender`` keeps the raw name, because the graph routes tool
    results back to the agent by that exact string.
    """
    reply = agent.invoke(state)
    if not isinstance(reply, FunctionMessage):
        message_name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
        carried_fields = reply.dict(exclude={"type", "name"})
        reply = HumanMessage(**carried_fields, name=message_name)
    return {"messages": [reply], "sender": name}
# Initialize LLM: one ChatOpenAI client shared by both agents, authenticated
# with the key checked at startup.
llm = ChatOpenAI(api_key=OPENAI_API_KEY)

# Researcher: may only call the Tavily search tool.
research_agent = create_agent(
    llm=llm,
    tools=[tavily_tool],
    system_message="You should provide accurate data for the chart generator to use.",
)

# Chart Generator: may only call the python_repl charting tool.
chart_agent = create_agent(
    llm=llm,
    tools=[python_repl],
    system_message="Any charts you display will be visible by the user.",
)
# Define workflow graph: two agents that hand work back and forth, plus one
# shared tool-execution node.
workflow = StateGraph(AgentState)

agent_runnables = {"Researcher": research_agent, "Chart Generator": chart_agent}
# When an agent neither calls a tool nor declares FINAL ANSWER (router returns
# "continue"), control passes to its peer.
agent_peers = {"Researcher": "Chart Generator", "Chart Generator": "Researcher"}

for agent_name, agent_runnable in agent_runnables.items():
    workflow.add_node(
        agent_name,
        functools.partial(agent_node, agent=agent_runnable, name=agent_name),
    )
workflow.add_node("call_tool", tool_node)

for agent_name, peer_name in agent_peers.items():
    workflow.add_conditional_edges(
        agent_name,
        router,
        {"continue": peer_name, "call_tool": "call_tool", "end": END},
    )
# After a tool runs, return control to whichever agent requested it.
workflow.add_conditional_edges(
    "call_tool",
    lambda x: x["sender"],
    {agent_name: agent_name for agent_name in agent_runnables},
)

workflow.set_entry_point("Researcher")
graph = workflow.compile()
# Streamlit UI
st.title("Multi-Agent Workflow")
user_query = st.text_area("Enter your query:", "Fetch Malaysia's GDP over the past 5 years and draw a line graph.")


def _render_message(message):
    """Display one message taken from a node's state update.

    tool_node emits chart results as plain dicts with a base64 PNG under
    "image". The other messages are LangChain message objects whose text is
    in ``.content``. Messages with empty content, such as pure function-call
    requests, are skipped.
    """
    if isinstance(message, dict):
        encoded_image = message.get("image")
        content = message.get("content")
    else:
        encoded_image = None
        content = getattr(message, "content", None)

    if encoded_image is not None:
        try:
            decoded_image = BytesIO(base64.b64decode(encoded_image))
            st.image(decoded_image, caption="Generated Chart", use_column_width=True)
        except Exception as e:
            st.error(f"Failed to decode and display the image: {repr(e)}")
    elif content:
        st.write(content)


if st.button("Run Workflow"):
    if not user_query.strip():
        st.warning("Please enter a query before running the workflow.")
        st.stop()
    st.write("Running workflow...")
    with st.spinner("Processing..."):
        try:
            messages = [HumanMessage(content=user_query)]
            # Each streamed step maps the node that just ran to the state update
            # it returned, e.g. {"Researcher": {"messages": [...], "sender": ...}}.
            # "messages" is therefore never a top-level key of the step itself.
            for step in graph.stream({"messages": messages}, {"recursion_limit": 150}):
                st.write("Step Details:", step)
                for node_name, update in step.items():
                    # NOTE(review): older langgraph releases also yield a final
                    # {END: <full state>} step; skip it so messages aren't shown twice.
                    if node_name == END or not isinstance(update, dict):
                        continue
                    for message in update.get("messages", []):
                        _render_message(message)
        except Exception as e:
            st.error(f"An error occurred: {e}")