"""Streamlit app that renders a matplotlib chart from user-supplied Python code.

The ``python_repl`` tool executes the code, captures the resulting figure and
returns it base64-encoded. A Tavily search tool and a minimal LangGraph
workflow wrapping both tools are also constructed, although the UI below only
calls ``python_repl`` directly.
"""

import os
import base64
from io import BytesIO
from typing import Any, TypedDict

import streamlit as st
import matplotlib.pyplot as plt
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.utilities import PythonREPL
from langgraph.graph import StateGraph, END
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation

# Set up environment variables for API keys
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Validate API keys: stop rendering the rest of the app if either is missing.
if not TAVILY_API_KEY or not OPENAI_API_KEY:
    st.error("API keys are missing. Please set TAVILY_API_KEY and OPENAI_API_KEY as secrets.")
    st.stop()

# Initialize tools
tavily_tool = TavilySearchResults(max_results=5)


@tool
def python_repl(code: str):
    """Executes Python code to generate a chart and returns the chart as a base64-encoded image."""
    # NOTE: the docstring above doubles as the tool description LangChain
    # exposes, so it is kept verbatim. Details:
    #   code -- Python source that draws with matplotlib (``plt`` is predefined).
    #   returns {"status": "success", "image": <base64 PNG>} on success, or
    #   {"status": "failed", "error": <message>} if execution/saving fails or
    #   the code produced no figure.
    #
    # SECURITY NOTE(review): this executes arbitrary code (here typed into the
    # UI) with the full privileges of the Streamlit process. Expose the app only
    # to trusted users, or sandbox the execution.

    # Start from a clean pyplot state so figures left over from an earlier
    # (possibly failed) run are not picked up by savefig below.
    plt.close("all")
    # One namespace dict, not separate globals/locals: with two dicts, names the
    # user's code assigns are invisible inside functions (and, before Python
    # 3.12, comprehensions) it defines, causing spurious NameErrors.
    namespace = {"plt": plt}
    try:
        exec(code, namespace)

        # Without this check, savefig would silently render an empty canvas.
        if not plt.get_fignums():
            return {"status": "failed", "error": "The code did not create a matplotlib figure."}

        # Save the current figure to an in-memory PNG and base64-encode it.
        buf = BytesIO()
        plt.savefig(buf, format="png")
        encoded_image = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"status": "success", "image": encoded_image}
    except Exception as e:
        return {"status": "failed", "error": str(e)}
    finally:
        # Release every figure the code created, on success and failure alike.
        plt.close("all")


tools = [tavily_tool, python_repl]
tool_executor = ToolExecutor(tools)


class ToolState(TypedDict, total=False):
    """State carried through the workflow graph."""

    tool: str  # name of the tool to invoke
    tool_input: Any  # input forwarded to that tool
    result: Any  # output returned by the tool call


def _call_tool(state: ToolState) -> dict:
    """Invoke the tool named in ``state`` and store its output under ``result``."""
    # Build the invocation from the two relevant keys only. Spreading the whole
    # state would also pass ``result`` into ToolInvocation.
    invocation = ToolInvocation(tool=state["tool"], tool_input=state["tool_input"])
    return {"result": tool_executor.invoke(invocation)}


# Define the multi-agent workflow. StateGraph needs a state schema; calling
# StateGraph() with no arguments fails when this script starts.
workflow = StateGraph(ToolState)
workflow.add_node("call_tool", _call_tool)
workflow.set_entry_point("call_tool")
workflow.add_edge("call_tool", END)

# Streamlit UI
st.title("Multi-Agent Workflow")
st.markdown("### Generate a Chart from Python Code")

DEFAULT_CHART_CODE = """
import matplotlib.pyplot as plt
years = [2019, 2020, 2021, 2022, 2023]
gdp = [300, 310, 330, 360, 399]
plt.figure(figsize=(10, 6))
plt.plot(years, gdp, marker='o', color='b', linestyle='-')
plt.title('Malaysia GDP Over 5 Years')
plt.xlabel('Year')
plt.ylabel('GDP (in billion USD)')
plt.grid(True)
"""

code_input = st.text_area("Enter Python code for the chart:", DEFAULT_CHART_CODE)

if st.button("Generate Chart"):
    st.write("Generating chart...")
    with st.spinner("Processing..."):
        # Invoke the tool through its Runnable API; calling a LangChain tool
        # object directly (python_repl(...)) is deprecated.
        response = python_repl.invoke({"code": code_input})

        # Check response status
        if response["status"] == "success" and "image" in response:
            encoded_image = response["image"]
            try:
                # Decode the base64-encoded PNG back into bytes for st.image.
                decoded_image = BytesIO(base64.b64decode(encoded_image))
                st.image(decoded_image, caption="Generated Chart", use_column_width=True)
            except Exception as e:
                st.error(f"Failed to decode and display the image: {str(e)}")
        else:
            st.error(f"Failed to generate chart: {response.get('error', 'Unknown error')}")

st.markdown("### Example Queries")
st.write("1. Plot GDP of Malaysia over 5 years")
st.write("2. Create a bar chart of sales data")