test-space-mas / app.py
DrishtiSharma's picture
Update app.py
db0c72e verified
raw
history blame
3.35 kB
import os
import base64
from io import BytesIO
import streamlit as st
import matplotlib.pyplot as plt
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.utilities import PythonREPL
from langgraph.graph import StateGraph, END
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
# Set up environment variables for API keys (read from the process environment,
# e.g. Hugging Face Space secrets).
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Validate API keys: halt this Streamlit script run if any key is unset or empty,
# and name the missing key(s) so the deployer knows exactly which secret to add.
_missing_keys = [
    key_name
    for key_name, key_value in (
        ("TAVILY_API_KEY", TAVILY_API_KEY),
        ("OPENAI_API_KEY", OPENAI_API_KEY),
    )
    if not key_value
]
if _missing_keys:
    st.error(
        "API keys are missing: "
        + ", ".join(_missing_keys)
        + ". Please set TAVILY_API_KEY and OPENAI_API_KEY as secrets."
    )
    st.stop()
# Initialize tools.
# Web-search tool; it is registered alongside python_repl in `tools` further down.
# The constant caps how many search results a single query may return.
_TAVILY_MAX_RESULTS = 5
tavily_tool = TavilySearchResults(max_results=_TAVILY_MAX_RESULTS)
@tool
def python_repl(code: str):
    """Executes Python code to generate a chart and returns the chart as a base64-encoded image."""
    # NOTE: the docstring above doubles as the tool description that LangChain's
    # @tool exposes, so it is intentionally left short and unchanged.
    #
    # Returns {"status": "success", "image": <base64 PNG>} on success, or
    # {"status": "failed", "error": <message>} on any exception; errors are
    # reported in-band rather than raised.
    #
    # SECURITY: `code` is executed with exec() in this process. In this app it
    # comes straight from a Streamlit text area, i.e. arbitrary code execution.
    # Only deploy where every user is trusted, or move execution into a sandbox.
    #
    # A single namespace dict (rather than separate globals/locals) makes the
    # snippet run like a module: functions and comprehensions it defines can see
    # its own top-level names.
    namespace = {"plt": plt}
    try:
        exec(code, namespace)
        if not plt.get_fignums():
            # Saving now would silently produce a blank image.
            return {
                "status": "failed",
                "error": "The code ran but did not create a matplotlib figure.",
            }
        with BytesIO() as buf:
            plt.savefig(buf, format="png")
            encoded_image = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"status": "success", "image": encoded_image}
    except Exception as e:
        return {"status": "failed", "error": str(e)}
    finally:
        # Close every figure, on success and failure alike. pyplot state is
        # process-global, so leftovers would otherwise bleed into the next run.
        plt.close("all")
from typing import Any, TypedDict

tools = [tavily_tool, python_repl]
tool_executor = ToolExecutor(tools)


class ToolCallState(TypedDict, total=False):
    """Graph state for a single tool call."""

    tool: str  # name of the tool to run, e.g. "python_repl"
    tool_input: Any  # arguments forwarded unchanged to that tool
    result: Any  # whatever the tool returned


def _call_tool(state: ToolCallState) -> dict:
    """Run the tool named in ``state`` and store its output under ``result``."""
    invocation = ToolInvocation(tool=state["tool"], tool_input=state["tool_input"])
    # A graph node must return a dict of state updates. The raw tool output
    # (e.g. python_repl's {"status": ..., "image": ...}) is not a set of state
    # channels, so wrap it under the "result" key.
    return {"result": tool_executor.invoke(invocation)}


# Define the multi-agent workflow. StateGraph needs a state schema; constructing
# it with no arguments raises when the module is imported.
workflow = StateGraph(ToolCallState)
workflow.add_node("call_tool", _call_tool)
workflow.set_entry_point("call_tool")
workflow.add_edge("call_tool", END)
compiled_workflow = workflow.compile()
# Streamlit UI: a page header, then an editable code box pre-filled with a
# sample matplotlib script that draws a line chart of GDP by year.
_DEFAULT_CHART_CODE = """
import matplotlib.pyplot as plt
years = [2019, 2020, 2021, 2022, 2023]
gdp = [300, 310, 330, 360, 399]
plt.figure(figsize=(10, 6))
plt.plot(years, gdp, marker='o', color='b', linestyle='-')
plt.title('Malaysia GDP Over 5 Years')
plt.xlabel('Year')
plt.ylabel('GDP (in billion USD)')
plt.grid(True)
"""
st.title("Multi-Agent Workflow")
st.markdown("### Generate a Chart from Python Code")
code_input = st.text_area("Enter Python code for the chart:", _DEFAULT_CHART_CODE)
# On click: run the user's code through the chart tool and render the PNG it
# returns, or show the tool's error message.
if st.button("Generate Chart"):
    st.write("Generating chart...")
    with st.spinner("Processing..."):
        # Use the Runnable `.invoke` API. Calling a LangChain tool object
        # directly (`python_repl(code)`) is deprecated. The input dict is keyed
        # by the tool function's parameter name.
        response = python_repl.invoke({"code": code_input})
    # The tool reports failures in-band through "status" instead of raising.
    if response.get("status") == "success" and "image" in response:
        try:
            # Decode the base64-encoded PNG back into a file-like buffer.
            decoded_image = BytesIO(base64.b64decode(response["image"]))
            # NOTE(review): newer Streamlit releases deprecate use_column_width
            # in favour of use_container_width; kept for older-version compatibility.
            st.image(decoded_image, caption="Generated Chart", use_column_width=True)
        except Exception as e:
            st.error(f"Failed to decode and display the image: {str(e)}")
    else:
        st.error(f"Failed to generate chart: {response.get('error', 'Unknown error')}")
# Footer: a heading followed by sample prompts, one line each, in order.
st.markdown("### Example Queries")
for example_line in (
    "1. Plot GDP of Malaysia over 5 years",
    "2. Create a bar chart of sales data",
):
    st.write(example_line)