Update app.py
app.py
CHANGED
@@ -1,108 +1,148 @@
Before:

import os
import gradio as gr
# ...
import requests
from typing import Dict, List
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

# ...
@tool
def get_lat_lng(location_description: str) -> dict[str, float]:
    """Get the latitude and longitude of a location."""
    # ...

@tool
def get_weather(lat: float, lng: float) -> dict[str, str]:
    """Get the weather at a location."""
    # ...

llm = ChatOpenAI(temperature=0, model="gpt-4")
memory = MemorySaver()
tools = [get_lat_lng, get_weather]
agent_executor = create_react_agent(llm, tools, checkpointer=memory)

# ...
        {"messages": past_messages},
        config={"configurable": {"thread_id": "abc123"}}
    ):
        # ...
        yield messages_to_display

# Create the Gradio interface
demo = gr.ChatInterface(
    fn=stream_from_agent,
    type="messages",
    title="🌤️ Weather Assistant",
    description="Ask about the weather anywhere! Watch as I gather the information step by step.",
    examples=[
        "What's the weather like in Tokyo?",
        "Is it sunny in Paris right now?",
        "Should I bring an umbrella in New York today?"
    ],
    # ...
        "https://cdn2.iconfinder.com/data/icons/city-icons-for-offscreen-magazine/80/new-york-256.png"
    ],
    save_history=True,
    editable=True
)

if __name__ == "__main__":
    # Load environment variables
    try:
        from dotenv import load_dotenv
        # ...
    except ImportError:
        pass
After:

import os
import gradio as gr
# Keep using gradio.ChatMessage for type hints if needed, but not for yielding complex structures directly to ChatInterface
# from gradio import ChatMessage  # Maybe remove this import if not used elsewhere
import requests
from typing import Dict, List, AsyncGenerator, Union, Tuple
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage  # Use LangChain messages internally
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

# --- Tools remain the same ---
@tool
def get_lat_lng(location_description: str) -> dict[str, float]:
    """Get the latitude and longitude of a location."""
    print(f"Tool: Getting lat/lng for {location_description}")
    # Replace with actual API call in a real app
    if "tokyo" in location_description.lower():
        return {"lat": 35.6895, "lng": 139.6917}
    elif "paris" in location_description.lower():
        return {"lat": 48.8566, "lng": 2.3522}
    elif "new york" in location_description.lower():
        return {"lat": 40.7128, "lng": -74.0060}
    else:
        return {"lat": 51.5072, "lng": -0.1276}  # Default London

@tool
def get_weather(lat: float, lng: float) -> dict[str, str]:
    """Get the weather at a location."""
    print(f"Tool: Getting weather for lat={lat}, lng={lng}")
    # Replace with actual API call in a real app
    # Dummy logic based on lat
    if lat > 45:  # Northern locations
        return {"temperature": "15°C", "description": "Cloudy"}
    elif lat > 30:  # Mid locations
        return {"temperature": "25°C", "description": "Sunny"}
    else:  # Southern locations
        return {"temperature": "30°C", "description": "Very Sunny"}
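The tool bodies above are stubs. One way to wire them to a live service, sketched here against the free Open-Meteo geocoding and forecast endpoints (no API key needed; the fetch_* helper names are illustrative, not part of the app, and the response shapes are worth verifying):

# Hypothetical real implementations for the two tool stubs above,
# assuming Open-Meteo's public endpoints.
import requests

def fetch_lat_lng(location_description: str) -> dict[str, float]:
    # Geocode a free-form place name to coordinates
    resp = requests.get(
        "https://geocoding-api.open-meteo.com/v1/search",
        params={"name": location_description, "count": 1},
        timeout=10,
    )
    resp.raise_for_status()
    top = resp.json()["results"][0]
    return {"lat": top["latitude"], "lng": top["longitude"]}

def fetch_weather(lat: float, lng: float) -> dict[str, str]:
    # Fetch current conditions for the coordinates
    resp = requests.get(
        "https://api.open-meteo.com/v1/forecast",
        params={"latitude": lat, "longitude": lng, "current_weather": True},
        timeout=10,
    )
    resp.raise_for_status()
    current = resp.json()["current_weather"]
    return {"temperature": f"{current['temperature']}°C",
            "description": f"WMO weather code {current['weathercode']}"}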
# --- Modified Agent Function ---
# Change return type hint for clarity if desired, e.g., AsyncGenerator[str, None]
# Or keep it simple, Gradio infers based on yields
async def stream_from_agent(message: str, history: List[List[str]]) -> AsyncGenerator[str, None]:
    """Processes the message through the LangChain agent, yielding intermediate steps as strings."""

    # Convert Gradio history to LangChain messages
    lc_messages = []
    for user_msg, ai_msg in history:
        if user_msg:
            lc_messages.append(HumanMessage(content=user_msg))
        if ai_msg:
            # Important: handle potential previous intermediate strings from the AI.
            # If ai_msg contains markers like "🛠️ Using", it was an intermediate step.
            # For simplicity here, we assume full AI responses were stored previously.
            # A more robust solution might involve storing message types in history.
            if not ai_msg.startswith("🛠️ Using") and not ai_msg.startswith("Result:"):
                lc_messages.append(AIMessage(content=ai_msg))

    lc_messages.append(HumanMessage(content=message))

    # Initialize the agent (consider initializing outside the function if stateful across calls)
    llm = ChatOpenAI(temperature=0, model="gpt-4")
    memory = MemorySaver()  # Be mindful of memory state if the agent is re-initialized every time
    tools = [get_lat_lng, get_weather]
    agent_executor = create_react_agent(llm, tools, checkpointer=memory)

    # Use a unique thread_id per session if needed, or manage state differently.
    # Using a fixed one like "abc123" means all users share the same memory if server restarts aren't frequent.
    thread_id = "user_session_" + os.urandom(4).hex()  # Example: generate a unique ID

    full_response = ""  # Accumulate the response parts
    chunk = None  # Keep the last event around for the post-loop check below

    async for chunk in agent_executor.astream_events(
        {"messages": lc_messages},
        config={"configurable": {"thread_id": thread_id}},
        version="v1"  # Use v1 for events streaming
    ):
        event = chunk["event"]
        data = chunk["data"]

        if event == "on_chat_model_stream":
            # Stream content from the LLM (final answer parts)
            content = data["chunk"].content
            if content:
                full_response += content
                yield full_response  # Yield the accumulating final response

        elif event == "on_tool_start":
            # Show tool usage start; the tool name lives on the event itself, not in data
            tool_input_str = str(data.get("input", ""))  # Get tool input safely
            yield f"🛠️ Using tool: **{chunk['name']}** with input: `{tool_input_str}`"

        elif event == "on_tool_end":
            # Show tool result (optional, can make chat verbose)
            tool_output_str = str(data.get("output", ""))  # Get tool output safely
            # Find the corresponding start message to potentially update, or just yield a new message.
            # For simplicity, just yield the result as a new message line.
            yield f"Tool **{chunk['name']}** finished.\nResult: `{tool_output_str}`"
            # Yield the accumulated response again after tool use in case the LLM continues
            if full_response:
                yield full_response

    # Ensure the final accumulated response is yielded if not already done by the last LLM chunk
    # (the stream might end on a tool event sometimes)
    if full_response and (chunk is None or chunk["event"] != "on_chat_model_stream"):
        yield full_response
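To sanity-check the generator outside Gradio, a hypothetical smoke test (smoke_test is illustrative; it needs OPENAI_API_KEY set and prints the same strings the UI would render):

# Drive the agent directly and print each UI update as it streams
import asyncio

async def smoke_test() -> None:
    async for update in stream_from_agent("What's the weather like in Tokyo?", history=[]):
        print("UI update:", update)

asyncio.run(smoke_test())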
# --- Gradio Interface (mostly unchanged) ---
demo = gr.ChatInterface(
    fn=stream_from_agent,
    # No type="messages" needed when yielding strings; ChatInterface handles it.
    title="🌤️ Weather Assistant",
    description="Ask about the weather anywhere! Watch as I gather the information step by step.",
    examples=[
        ["What's the weather like in Tokyo?"],
        ["Is it sunny in Paris right now?"],
        ["Should I bring an umbrella in New York today?"]
    ],
    # Example icons removed for simplicity, ensure they are accessible if added back
    cache_examples=False,  # Turn off caching initially to ensure it's not the issue
    save_history=True,
    editable=True,
)
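If structured tool steps are wanted later instead of flat strings, recent Gradio releases accept gr.ChatMessage objects in messages mode; a sketch (stream_as_messages is illustrative, and the metadata rendering is worth verifying against the installed Gradio version):

# Hypothetical variant: with type="messages", tool steps can render as
# collapsible "thought" bubbles via ChatMessage metadata titles.
async def stream_as_messages(message: str, history: list):
    steps = [gr.ChatMessage(
        role="assistant",
        content="Looking up coordinates...",
        metadata={"title": "🛠️ get_lat_lng"},
    )]
    yield steps  # each yield replaces the pending output, so accumulate a list
    steps.append(gr.ChatMessage(role="assistant", content="It is 25°C and sunny."))
    yield steps

demo_messages = gr.ChatInterface(fn=stream_as_messages, type="messages")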
if __name__ == "__main__":
    # Load environment variables
    try:
        from dotenv import load_dotenv
        print("Attempting to load .env file...")
        loaded = load_dotenv()
        if loaded:
            print(".env file loaded successfully.")
        else:
            print(".env file not found or empty.")
        # Check if the key is loaded
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if openai_api_key:
            print("OPENAI_API_KEY found.")
        else:
            print("Warning: OPENAI_API_KEY not found in environment variables.")
    except ImportError:
        print("dotenv not installed, skipping .env load.")

    # Add server_name="0.0.0.0" if running in Docker or need external access
    demo.launch(debug=True, server_name="0.0.0.0")
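For local runs, load_dotenv() looks for a .env file in the working directory; a minimal example (key value elided, never commit real keys):

# .env — read by python-dotenv at startup
OPENAI_API_KEY=...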