DrishtiSharma commited on
Commit
db0c72e
·
verified ·
1 Parent(s): f6e2df7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -148
app.py CHANGED
@@ -1,30 +1,16 @@
1
  import os
2
- import json
3
- import re
4
  import base64
5
- import streamlit as st
6
  from io import BytesIO
7
- from langchain_core.utils.function_calling import convert_to_openai_function
8
- from langchain_core.messages import (
9
- AIMessage,
10
- BaseMessage,
11
- ChatMessage,
12
- FunctionMessage,
13
- HumanMessage,
14
- )
15
- from langchain.tools.render import format_tool_to_openai_function
16
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
17
- from langgraph.graph import END, StateGraph
18
- from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
19
  from langchain_core.tools import tool
 
 
20
  from langchain_community.tools.tavily_search import TavilySearchResults
21
  from langchain_experimental.utilities import PythonREPL
22
- from langchain_openai import ChatOpenAI
23
- from typing import Annotated, Sequence
24
- from typing_extensions import TypedDict
25
- import operator
26
- import functools
27
- import matplotlib.pyplot as plt
28
 
29
  # Set up environment variables for API keys
30
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
@@ -35,20 +21,14 @@ if not TAVILY_API_KEY or not OPENAI_API_KEY:
35
  st.error("API keys are missing. Please set TAVILY_API_KEY and OPENAI_API_KEY as secrets.")
36
  st.stop()
37
 
38
- # Define the AgentState class
39
- class AgentState(TypedDict):
40
- messages: Annotated[Sequence[BaseMessage], operator.add]
41
- sender: str
42
-
43
  # Initialize tools
44
  tavily_tool = TavilySearchResults(max_results=5)
45
- repl = PythonREPL()
46
 
47
  @tool
48
- def python_repl(code: Annotated[str, "The python code to execute to generate your chart."]):
49
  """Executes Python code to generate a chart and returns the chart as a base64-encoded image."""
50
  try:
51
- # Execute the code
52
  exec_globals = {"plt": plt}
53
  exec_locals = {}
54
  exec(code, exec_globals, exec_locals)
@@ -66,130 +46,53 @@ def python_repl(code: Annotated[str, "The python code to execute to generate you
66
  encoded_image = base64.b64encode(buf.getvalue()).decode("utf-8")
67
  return {"status": "success", "image": encoded_image}
68
  except Exception as e:
69
- return {"status": "failed", "error": repr(e)}
70
 
71
  tools = [tavily_tool, python_repl]
72
-
73
- # Define a tool executor
74
  tool_executor = ToolExecutor(tools)
75
 
76
- # Define tool node
77
- def tool_node(state):
78
- """Executes tools in the graph."""
79
- messages = state["messages"]
80
- last_message = messages[-1]
81
- tool_input = json.loads(last_message.additional_kwargs["function_call"]["arguments"])
82
- if len(tool_input) == 1 and "__arg1" in tool_input:
83
- tool_input = next(iter(tool_input.values()))
84
- tool_name = last_message.additional_kwargs["function_call"]["name"]
85
- action = ToolInvocation(tool=tool_name, tool_input=tool_input)
86
- response = tool_executor.invoke(action)
87
- if isinstance(response, dict) and response.get("status") == "success" and "image" in response:
88
- return {
89
- "messages": [
90
- {
91
- "role": "assistant",
92
- "content": "Image generated successfully.",
93
- "image": response["image"],
94
- }
95
- ]
96
- }
97
- else:
98
- function_message = FunctionMessage(
99
- content=f"{tool_name} response: {str(response)}", name=action.tool
100
- )
101
- return {"messages": [function_message]}
102
-
103
- # Define router
104
- def router(state):
105
- """Determines the next step in the workflow."""
106
- messages = state["messages"]
107
- last_message = messages[-1]
108
- if "function_call" in last_message.additional_kwargs:
109
- return "call_tool"
110
- if "FINAL ANSWER" in last_message.content:
111
- return "end"
112
- return "continue"
113
-
114
- # Define agent creation function
115
- def create_agent(llm, tools, system_message: str):
116
- """Creates an agent."""
117
- functions = [convert_to_openai_function(t) for t in tools]
118
- prompt = ChatPromptTemplate.from_messages(
119
- [
120
- (
121
- "system",
122
- "You are a helpful AI assistant, collaborating with other assistants."
123
- " Use the provided tools to progress towards answering the question."
124
- " If you are unable to fully answer, that's OK, another assistant with different tools "
125
- " will help where you left off. Execute what you can to make progress."
126
- " If you or any of the other assistants have the final answer or deliverable,"
127
- " prefix your response with FINAL ANSWER so the team knows to stop."
128
- " You have access to the following tools: {tool_names}.\n{system_message}",
129
- ),
130
- MessagesPlaceholder(variable_name="messages"),
131
- ]
132
- )
133
- prompt = prompt.partial(system_message=system_message)
134
- prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
135
- return prompt | llm.bind_functions(functions)
136
-
137
- # Define agent node
138
- def agent_node(state, agent, name):
139
- result = agent.invoke(state)
140
- if isinstance(result, FunctionMessage):
141
- pass
142
- else:
143
- # Sanitize the name field to match OpenAI's naming conventions
144
- sanitized_name = re.sub(r"[^a-zA-Z0-9_-]", "_", name)
145
- result = HumanMessage(**result.dict(exclude={"type", "name"}), name=sanitized_name)
146
- return {"messages": [result], "sender": name}
147
-
148
- # Initialize LLM
149
- llm = ChatOpenAI(api_key=OPENAI_API_KEY)
150
-
151
- # Create agents
152
- research_agent = create_agent(
153
- llm, [tavily_tool], system_message="You should provide accurate data for the chart generator to use."
154
- )
155
- chart_agent = create_agent(
156
- llm, [python_repl], system_message="Any charts you display will be visible by the user."
157
- )
158
-
159
- # Define workflow graph
160
- workflow = StateGraph(AgentState)
161
- workflow.add_node("Researcher", functools.partial(agent_node, agent=research_agent, name="Researcher"))
162
- workflow.add_node("Chart Generator", functools.partial(agent_node, agent=chart_agent, name="Chart Generator"))
163
- workflow.add_node("call_tool", tool_node)
164
- workflow.add_conditional_edges("Researcher", router, {"continue": "Chart Generator", "call_tool": "call_tool", "end": END})
165
- workflow.add_conditional_edges("Chart Generator", router, {"continue": "Researcher", "call_tool": "call_tool", "end": END})
166
- workflow.add_conditional_edges("call_tool", lambda x: x["sender"], {"Researcher": "Researcher", "Chart Generator": "Chart Generator"})
167
- workflow.set_entry_point("Researcher")
168
- graph = workflow.compile()
169
 
170
  # Streamlit UI
171
  st.title("Multi-Agent Workflow")
172
- user_query = st.text_area("Enter your query:", "Fetch Malaysia's GDP over the past 5 years and draw a line graph.")
173
- if st.button("Run Workflow"):
174
- st.write("Running workflow...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  with st.spinner("Processing..."):
176
- try:
177
- messages = [HumanMessage(content=user_query)]
178
- for step in graph.stream({"messages": messages}, {"recursion_limit": 150}):
179
- st.write("Step Details:", step)
180
- if "messages" in step:
181
- for message in step["messages"]:
182
- if "image" in message:
183
- try:
184
- # Decode the base64-encoded image
185
- encoded_image = message["image"]
186
- decoded_image = BytesIO(base64.b64decode(encoded_image))
187
- # Display the image
188
- st.image(decoded_image, caption="Generated Chart", use_column_width=True)
189
- except Exception as e:
190
- st.error(f"Failed to decode and display the image: {repr(e)}")
191
- elif "content" in message:
192
- # Display any text content
193
- st.write(message["content"])
194
- except Exception as e:
195
- st.error(f"An error occurred: {e}")
 
1
  import os
 
 
2
  import base64
 
3
  from io import BytesIO
4
+
5
+ import streamlit as st
6
+ import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
 
7
  from langchain_core.tools import tool
8
+ from langchain_core.utils.function_calling import convert_to_openai_function
9
+ from langchain_openai import ChatOpenAI
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_experimental.utilities import PythonREPL
12
+ from langgraph.graph import StateGraph, END
13
+ from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
 
 
 
 
14
 
15
  # Set up environment variables for API keys
16
  TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
 
21
  st.error("API keys are missing. Please set TAVILY_API_KEY and OPENAI_API_KEY as secrets.")
22
  st.stop()
23
 
 
 
 
 
 
24
  # Initialize tools
25
  tavily_tool = TavilySearchResults(max_results=5)
 
26
 
27
  @tool
28
+ def python_repl(code: str):
29
  """Executes Python code to generate a chart and returns the chart as a base64-encoded image."""
30
  try:
31
+ # Execute the provided Python code
32
  exec_globals = {"plt": plt}
33
  exec_locals = {}
34
  exec(code, exec_globals, exec_locals)
 
46
  encoded_image = base64.b64encode(buf.getvalue()).decode("utf-8")
47
  return {"status": "success", "image": encoded_image}
48
  except Exception as e:
49
+ return {"status": "failed", "error": str(e)}
50
 
51
  tools = [tavily_tool, python_repl]
 
 
52
  tool_executor = ToolExecutor(tools)
53
 
54
+ # Define the multi-agent workflow
55
+ workflow = StateGraph()
56
+ workflow.add_node("call_tool", lambda state: tool_executor.invoke(ToolInvocation(**state)))
57
+ workflow.set_entry_point("call_tool")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Streamlit UI
60
  st.title("Multi-Agent Workflow")
61
+ st.markdown("### Generate a Chart from Python Code")
62
+ code_input = st.text_area(
63
+ "Enter Python code for the chart:",
64
+ """
65
+ import matplotlib.pyplot as plt
66
+ years = [2019, 2020, 2021, 2022, 2023]
67
+ gdp = [300, 310, 330, 360, 399]
68
+ plt.figure(figsize=(10, 6))
69
+ plt.plot(years, gdp, marker='o', color='b', linestyle='-')
70
+ plt.title('Malaysia GDP Over 5 Years')
71
+ plt.xlabel('Year')
72
+ plt.ylabel('GDP (in billion USD)')
73
+ plt.grid(True)
74
+ """
75
+ )
76
+
77
+ if st.button("Generate Chart"):
78
+ st.write("Generating chart...")
79
  with st.spinner("Processing..."):
80
+ # Invoke the Python REPL tool with the code
81
+ response = python_repl(code_input)
82
+
83
+ # Check response status
84
+ if response["status"] == "success" and "image" in response:
85
+ encoded_image = response["image"]
86
+ try:
87
+ # Decode the base64-encoded image
88
+ decoded_image = BytesIO(base64.b64decode(encoded_image))
89
+ # Display the image
90
+ st.image(decoded_image, caption="Generated Chart", use_column_width=True)
91
+ except Exception as e:
92
+ st.error(f"Failed to decode and display the image: {str(e)}")
93
+ else:
94
+ st.error(f"Failed to generate chart: {response.get('error', 'Unknown error')}")
95
+
96
+ st.markdown("### Example Queries")
97
+ st.write("1. Plot GDP of Malaysia over 5 years")
98
+ st.write("2. Create a bar chart of sales data")