# Spaces: Sleeping  (hosting-page status text captured with the file; not program code)
import os | |
import base64 | |
from io import BytesIO | |
import streamlit as st | |
import matplotlib.pyplot as plt | |
from langchain_core.tools import tool | |
from langchain_core.utils.function_calling import convert_to_openai_function | |
from langchain_openai import ChatOpenAI | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
from langchain_experimental.utilities import PythonREPL | |
from langgraph.graph import StateGraph, END | |
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation | |
# Read the service credentials from the process environment
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Both credentials are required: report the problem in the UI and stop the script
if not (TAVILY_API_KEY and OPENAI_API_KEY):
    st.error("API keys are missing. Please set TAVILY_API_KEY and OPENAI_API_KEY as secrets.")
    st.stop()

# Web-search tool, limited to five results per query
tavily_tool = TavilySearchResults(max_results=5)
def python_repl(code: str) -> dict:
    """Execute chart-drawing Python code and return the figure as a base64 PNG.

    The code runs with ``plt`` (``matplotlib.pyplot``) already bound in its
    namespace. After it finishes, the current matplotlib figure is rendered
    to PNG and base64-encoded.

    Args:
        code: Python source that draws onto a matplotlib figure.

    Returns:
        ``{"status": "success", "image": <base64-encoded PNG>}`` when the code
        runs and the figure is saved, otherwise
        ``{"status": "failed", "error": <exception message>}``.
    """
    # SECURITY: exec() runs arbitrary code with the full privileges of this
    # process. Only expose this to trusted users, or run it in a sandbox.

    # One dict is used as both globals and locals. With two separate dicts
    # exec() behaves like a class body, so functions and comprehensions in
    # the submitted code could not see names defined at its top level.
    namespace = {"plt": plt}
    try:
        exec(code, namespace)
        # Render the current figure into an in-memory PNG
        buf = BytesIO()
        plt.savefig(buf, format="png")
        encoded_image = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"status": "success", "image": encoded_image}
    except Exception as e:
        return {"status": "failed", "error": str(e)}
    finally:
        # Always release figures, including after a failure, so a broken
        # run cannot leak figures or bleed into the next chart.
        plt.close("all")
# Register the tools with the executor. python_repl is a plain function, so it
# is wrapped with `tool` to give it the name/schema interface of a LangChain
# tool; the bare function stays available for direct calls elsewhere.
tools = [tavily_tool, tool(python_repl)]
tool_executor = ToolExecutor(tools)

# Define the multi-agent workflow.
# StateGraph needs a state schema; a plain dict is used as a minimal one here.
# NOTE(review): switch to a TypedDict once the state shape is settled — the
# node below assumes the state holds exactly the ToolInvocation fields.
workflow = StateGraph(dict)
workflow.add_node("call_tool", lambda state: tool_executor.invoke(ToolInvocation(**state)))
workflow.set_entry_point("call_tool")
# Streamlit UI: page heading and the code editor
st.title("Multi-Agent Workflow")
st.markdown("### Generate a Chart from Python Code")

# Sample plotting script shown in the editor before the user types anything
default_chart_code = """
import matplotlib.pyplot as plt
years = [2019, 2020, 2021, 2022, 2023]
gdp = [300, 310, 330, 360, 399]
plt.figure(figsize=(10, 6))
plt.plot(years, gdp, marker='o', color='b', linestyle='-')
plt.title('Malaysia GDP Over 5 Years')
plt.xlabel('Year')
plt.ylabel('GDP (in billion USD)')
plt.grid(True)
"""
code_input = st.text_area("Enter Python code for the chart:", value=default_chart_code)
# Run the submitted code when the button is pressed and show the result
if st.button("Generate Chart"):
    st.write("Generating chart...")
    with st.spinner("Processing..."):
        # Hand the editor contents to the chart tool
        response = python_repl(code_input)

        chart_ready = response["status"] == "success" and "image" in response
        if not chart_ready:
            # Surface the tool's error message, if it supplied one
            st.error(f"Failed to generate chart: {response.get('error', 'Unknown error')}")
        else:
            try:
                # Turn the base64 payload back into PNG bytes for display
                png_stream = BytesIO(base64.b64decode(response["image"]))
                st.image(png_stream, caption="Generated Chart", use_column_width=True)
            except Exception as exc:
                st.error(f"Failed to decode and display the image: {exc}")
# Static list of sample requests shown under the chart area
st.markdown("### Example Queries")
for example_query in (
    "1. Plot GDP of Malaysia over 5 years",
    "2. Create a bar chart of sales data",
):
    st.write(example_query)