pictionagent-reachy / pictionagent.py
Thibault Hervier
Small fixes for UX
047f3bd
import os
# Fail fast at import time: the chat models used below need an OpenAI API key.
if not os.environ.get("OPENAI_API_KEY"):
    raise ValueError("OPENAI_API_KEY must be set")
import random
from typing import Any
import datetime
import asyncio
from io import BytesIO
from enum import Enum
import json
import requests
import gradio as gr
from PIL import Image, ImageDraw
from langchain_core.messages import HumanMessage, ToolMessage, AIMessage, SystemMessage
from langchain_core.tools import StructuredTool
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import InMemorySaver
from draw_input_schema import DrawInputSchema, Line, Point
from reachy_api import ReachyAPI
import logging
# Configure the root logger once for the whole application.
logging.basicConfig(level=logging.INFO)
# Directory containing this file; used to locate resources shipped next to it.
script_dir = os.path.dirname(os.path.abspath(__file__))
# Folder handed to ReachyAPI as `emotion_recordings_folder`.
EMOTION_RECORDINGS_FOLDER = os.path.join(script_dir, "emotion_recordings")
# How long a finished (non-idle) game state is kept before it is reset automatically.
GAME_TIMEOUT = 30 # seconds
# Prompt given to the ReAct game agent (passed as `prompt=` to create_react_agent).
# NOTE: this text is sent to the model at runtime; editing it changes agent behavior.
game_agent_prompt="""
You are a drawing agent and we are playing a game. I will give you a topic, and you will choose randomly an item to draw that will correspond to the given topic (for example, if the topic is "Animal", the item could be "Dog", "Cat", "Bird", etc.). Do not always choose first the obvious like Elephant or Eiffel Tower and do not choose the same item twice.
The user can only guess during the drawing, after the drawing is completed the user can no longer guess, so do not expect them to, just display the drawing image using provided tool and share the item drawn.
If it was guessed, share what it was and the id of the user who guessed it (if it was guessed).
"""
# System prompt used when evaluating a user's guess; the reply is parsed as JSON
# with the keys "guessed" and "response".
guess_agent_prompt="""
We are playing a draw and guess game. You have drawn an item and the user tries to guess it.
You must return a message indicating if the user has guessed the item correctly or not, but never tell the item to guess to the user if they have not guessed it correctly.
If the user is very close (for example with some small typos), you can validate the guess.
Answer using json format following this structure:
{
"guessed": true/false,
"response": "response to the user"
}
"""
class ItemTopic(Enum):
    """Topics the agent can be asked to pick an item from.

    The value is the human-readable label that is sent to the agent.
    """

    ANIMAL = "Animal"
    SHAPE = "Shape"
    FAMOUS_MONUMENT = "Famous monument"
class AgentStatus(Enum):
    """Phases of a game round.

    The values are the exact strings accepted by the ``set_status`` tool,
    so they must stay in sync with that tool's description.
    """

    IDLE = "Idle"
    CHOOSING_ITEM = "Choosing item"
    GENERATING_IMAGE_AND_POINTS = "Generating image and points"
    DRAWING = "Drawing"
    WAITING_FOR_GUESS = "Waiting for guess"
    DRAWING_COMPLETED = "Drawing completed"
    DRAWING_GUESSED = "Drawing guessed"
class User:
    """A player known to the game, identified by ``id``.

    Attributes:
        id: Identifier of the user.
        name: Display name; ``None`` until one has been assigned.
        last_ping: Time of the most recent keep-alive (creation time at first).
        score: Number of drawings the user has guessed.
    """

    id: str
    name: str  # None until a display name is assigned
    # `datetime` is imported as a module in this file, so the type is datetime.datetime.
    last_ping: datetime.datetime
    score: int

    def __init__(self, id: str):
        """Create a user with no name, a zero score and a fresh ping time.

        Args:
            id: Identifier of the user (name kept for keyword callers: ``User(id=...)``).
        """
        self.id = id
        self.last_ping = datetime.datetime.now()
        self.name = None
        self.score = 0
class PictionagentManager:
    """Holds the state of the draw-and-guess game and drives the game agent.

    It owns the LangGraph agent and its tools, the canvas the drawing is
    rendered on, the connected users and the leaderboard.
    """
    # Forwarded to ReachyAPI(fake_robot=...).
    fake_robot : bool
    # Set to True once async_init() has completed successfully.
    async_init_done : bool
    # Sent to the agent as "Number of points to draw".
    nb_drawing_points : int
    # Passed to reachy_api.draw() as duration_between_points.
    drawing_duration : float
    # Draw input parsed from the image/points generation tool result.
    current_draw_input : DrawInputSchema
    # Canvas (created in draw()) and its drawing handle.
    drawing_image : Image.Image
    image_draw : ImageDraw.ImageDraw
    # True while update_drawing() should keep polling the robot.
    is_drawing : bool
    # Points already rendered on the canvas.
    drawn_points : list[Point]
    # Reference image: its URL comes from the tool result, the image itself is
    # downloaded when the drawing is revealed.
    generated_source_image : Image.Image
    generated_source_image_url : str
    # Item name revealed at the end of the game (None until then).
    displayed_item_to_guess : str
    # ReAct agent created in async_init(), with its checkpointer and run config.
    game_agent : Any
    checkpointer : InMemorySaver
    config : dict
    # Topic of the game and the item chosen by the agent.
    current_topic : ItemTopic
    current_item : str
    # Items saved so far; sent to the agent at the start of each game.
    previous_items : list[str]
    agent_status : AgentStatus
    # Whether the current drawing was guessed, and by whom.
    is_drawing_guessed : bool
    guesser_id : str
    # Chat transcript (agent statuses, guesses, agent replies).
    agent_history : list[gr.ChatMessage]
    # True while start_game() is running the agent.
    game_running : bool
    # Users keyed by user id; entries are dropped when pings stop.
    connected_users : dict[str, User]
    # Users sorted by descending score, capped at 10 entries.
    leaderboard : list[User]
    reachy_api : ReachyAPI
    # Substring looked for in a ToolMessage name to recognise the generation tool.
    EXPECTED_TOOL_NAME_GENERATE_IMAGE_AND_POINTS = "generate_image_and_points"
def __init__(self, fake_robot: bool):
    """Create a manager in its idle state.

    Only plain state is set up here; everything that needs the event loop
    or the network (agent, tools, robot connection) is done in async_init().

    Args:
        fake_robot: Forwarded to ReachyAPI when the robot API is created.
    """
    self.fake_robot = fake_robot
    self.async_init_done = False

    # Drawing parameters and canvas state
    self.nb_drawing_points = 100
    self.drawing_duration = 0.05
    self.current_draw_input = None
    self.drawing_image = None
    self.image_draw = None
    self.is_drawing = False
    self.drawn_points = []

    # Reference image revealed at the end of a game
    self.generated_source_image = None
    self.generated_source_image_url = None
    self.displayed_item_to_guess = None

    # Agent objects, created in async_init(). `game_chat` is declared here as
    # well so the attribute exists even before async_init() has run.
    self.game_chat = None
    self.game_agent = None
    self.checkpointer = None
    self.config = None

    # Game state
    self.current_topic = ItemTopic.ANIMAL
    self.current_item = None
    self.previous_items = []
    self.agent_status = AgentStatus.IDLE
    self.is_drawing_guessed = False
    self.guesser_id = None
    self.agent_history = []
    self.game_running = False

    # Players
    self.connected_users = {}
    self.leaderboard = []

    # Robot API, created in async_init()
    self.reachy_api = None
async def async_init(self):
    """Create the agent, its tools and the robot connection (once).

    Does nothing if a previous call already completed. On failure the error
    is logged and re-raised, and `async_init_done` stays False so the call
    can be retried.

    Raises:
        Exception: whatever the tool discovery, agent creation or robot
            connection raised.
    """
    if self.async_init_done:
        return
    try:
        self.game_chat = ChatOpenAI(model="gpt-4o-mini")
        mcp_client = MultiServerMCPClient(
            {
                "gradio-draw-generation": {
                    "url": "https://agents-mcp-hackathon-concept-to-drawing-points.hf.space/gradio_api/mcp/sse",
                    "transport": "sse",
                }
            })
        tools = await mcp_client.get_tools()
        # Local tools; each method's docstring is used as the tool description.
        tools.append(StructuredTool.from_function(coroutine=self.set_status,
                                                  name="set_status",
                                                  description=self.set_status.__doc__))
        tools.append(StructuredTool.from_function(func=self.save_item,
                                                  name="save_item",
                                                  description=self.save_item.__doc__))
        tools.append(StructuredTool.from_function(coroutine=self.draw,
                                                  name="draw",
                                                  description=self.draw.__doc__))
        tools.append(StructuredTool.from_function(func=self.display_drawing_image_and_item,
                                                  name="display_drawing_image_and_item",
                                                  description=self.display_drawing_image_and_item.__doc__))
        self.checkpointer = InMemorySaver()
        self.config = {"configurable": {"thread_id": "1"}, "recursion_limit": 50}
        self.game_agent = create_react_agent(
            model=self.game_chat.bind_tools(tools, parallel_tool_calls=False),
            tools=tools,
            checkpointer=self.checkpointer,
            prompt=game_agent_prompt
        )
        self.reachy_api = ReachyAPI(fake_robot=self.fake_robot,
                                    emotion_recordings_folder=EMOTION_RECORDINGS_FOLDER,
                                    robot_ip="localhost")
        await self.reachy_api.connect()
        # Start the background loops only once everything above succeeded, so a
        # failed init followed by a retry does not leave duplicate loops running.
        # References are kept because the event loop only holds weak references
        # to tasks; an unreferenced task may be garbage-collected mid-execution.
        self._background_tasks = [
            asyncio.create_task(self._check_users()),
            asyncio.create_task(self._reset_game_after_timeout()),
        ]
        self.async_init_done = True
    except Exception as e:
        logging.error(f"Error initializing agent: {e}")
        # Bare raise keeps the original traceback intact.
        raise
#region Tools
async def set_status(self, status: str) -> str:
    """Set the status agent
    Args:
    status: The status to set, between "Choosing item", "Generating image and points", "Drawing", "Drawing completed" or "Drawing guessed"
    Returns:
    A message indicating that the status has been set.
    """
    # NOTE: the docstring above is passed to the LLM as the tool description
    # (see async_init), so its wording is part of the runtime behavior.
    # Raises ValueError if `status` is not one of the AgentStatus values.
    self.agent_status = AgentStatus(status)

    # Pick the set of robot animations matching the new status, if any.
    if self.agent_status == AgentStatus.CHOOSING_ITEM:
        animations = ["thoughtful1", "thoughtful2"]
    elif self.agent_status == AgentStatus.DRAWING_GUESSED or (
            self.agent_status == AgentStatus.DRAWING_COMPLETED and self.is_drawing_guessed):
        # NOTE(review): "enthusistic1" looks like a typo for "enthusiastic1"; it is
        # left unchanged because it presumably matches a recording name — confirm.
        animations = ["success2", "enthusistic1"]
    elif self.agent_status == AgentStatus.DRAWING_COMPLETED:
        # Completed without anyone guessing.
        animations = ["disgusted1", "anxiety1"]
    else:
        animations = []

    if animations:
        # Play the animation in the background so the tool returns immediately.
        emotion_task = asyncio.create_task(
            self.reachy_api.play_emotion(random.choice(animations)))
        # Keep a strong reference until the task finishes: the event loop only
        # holds weak references, so an unreferenced task may be collected early.
        pending = getattr(self, "_emotion_tasks", None)
        if pending is None:
            pending = self._emotion_tasks = set()
        pending.add(emotion_task)
        emotion_task.add_done_callback(pending.discard)

    self.agent_history.append(gr.ChatMessage(role="assistant", content="(" + status + ")"))
    return f"Status set to {status}"
def save_item(self, item: str) -> str:
"""Save the chosen item to draw
Args:
item: The item to save
Returns:
A message indicating that the item has been saved.
"""
self.current_item = item
self.previous_items.append(item)
return f"Item saved"
async def draw(self) -> str:
    """Draw points on the canvas, from previously generated points.
    Returns:
    A message indicating if the drawing has been completed or aborted.
    """
    # NOTE: the docstring above doubles as the tool description sent to the LLM.
    # Clear the image
    self.drawing_image = Image.new('RGB', (512, 512), 'white')
    self.image_draw = ImageDraw.Draw(self.drawing_image)
    self.drawn_points = []
    self.is_drawing = True
    # Mirror the robot's progress onto the canvas while it draws.
    drawing_update_task = asyncio.create_task(self.update_drawing())
    try:
        result = await self.reachy_api.draw(max_x=512, max_y=512,
                                            draw_input=self.current_draw_input,
                                            duration_between_points=self.drawing_duration)
        if not self.is_drawing_guessed:
            # Workaround: wait for 5 seconds to let users guess
            await self.set_status(AgentStatus.WAITING_FOR_GUESS.value)
            await asyncio.sleep(5)
    finally:
        # Always stop the canvas updater, even if the robot call failed;
        # otherwise its polling loop would keep running forever.
        self.is_drawing = False
        await drawing_update_task
    if result:
        return "Drawing completed"
    if self.guesser_id is None:
        # Stopped without a recorded guesser: there is no name to report.
        return "Drawing aborted"
    return "Drawing stopped as user '" + self.get_user_display_name(self.guesser_id) + "' guessed the drawn item"
def display_drawing_image_and_item(self) -> str:
    """Display the initial image of the drawing and the item drawn at the end of the game. Must be called at the end of the game.
    Returns:
    A message indicating that the drawing image and item have been displayed.
    """
    # NOTE: the docstring above doubles as the tool description sent to the LLM.
    # A timeout prevents the tool (and therefore the game) from hanging forever
    # on an unresponsive server.
    response = requests.get(self.generated_source_image_url, timeout=30)
    # Fail with a clear HTTP error instead of feeding an error page to PIL.
    response.raise_for_status()
    self.generated_source_image = Image.open(BytesIO(response.content))
    self.displayed_item_to_guess = self.current_item
    return "Drawing image and item displayed"
#endregion
async def start_game(self) -> None:
    """Reset the state and run the game agent for the current topic.

    The agent's stream is consumed until it ends: tool results are passed to
    handle_tool_message() and AI messages are appended to the chat history.
    Errors raised while streaming are logged and reported in the history;
    `game_running` is always cleared when this coroutine exits.
    """
    self.reset_game()
    self.game_running = True
    start_prompt = "Start the game for this topic: " + self.current_topic.value
    messages = [
        SystemMessage(content=start_prompt),
        SystemMessage(content="Previous items: " + str(self.previous_items)),
        SystemMessage(content="Number of points to draw: " + str(self.nb_drawing_points)),
    ]
    self.agent_history.append(gr.ChatMessage(role="user", content=start_prompt))
    # Get response from Agent
    try:
        async for event in self.game_agent.astream({"messages": messages}, self.config, stream_mode="values"):
            # In "values" mode each event carries the message list; look at the newest.
            response = event["messages"][-1]
            response.pretty_print()
            if isinstance(response, ToolMessage):
                self.handle_tool_message(response)
            elif isinstance(response, AIMessage):
                self.agent_history.append(gr.ChatMessage(role="assistant", content=response.content))
    except Exception as e:
        # logging.exception also records the traceback.
        logging.exception(f"Error in game_agent.astream: {e}")
        self.agent_history.append(gr.ChatMessage(role="assistant", content="Error in agent execution"))
        self.agent_status = AgentStatus.IDLE
    finally:
        # Also runs on cancellation (CancelledError is not an Exception), so the
        # game can never stay flagged as running after this coroutine has exited.
        self.game_running = False
def handle_tool_message(self, response: ToolMessage) -> None:
if self.EXPECTED_TOOL_NAME_GENERATE_IMAGE_AND_POINTS in response.name:
# Parse content as list if it's a string
content = response.content
if isinstance(content, str):
content = json.loads(content)
response_image = content[0]
response_points = content[1]
self.generated_source_image_url = response_image.split("Image URL: ")[1]
self.current_draw_input = DrawInputSchema.model_validate_json(response_points)
async def update_drawing(self):
while self.is_drawing:
new_drawn_points = self.reachy_api.get_drawn_positions()
if len(new_drawn_points) > len(self.drawn_points):
# Get new points by slicing from the length of self.drawn_points, to keep only the new points
new_points = new_drawn_points[len(self.drawn_points):]
if len(self.drawn_points) > 0:
last_point = self.drawn_points[-1]
else:
last_point = new_points[0]
for point in new_points:
self.image_draw.line([(last_point.x, last_point.y), (point.x, point.y)], fill='black', width=2)
last_point = point
self.drawn_points = new_drawn_points
await asyncio.sleep(0.1)
async def try_guess_drawing(self, guess_input: str, user_id: str) -> str:
    """Try to guess the drawing

    The guess is judged by a chat model. The first correct guess marks the
    drawing as guessed, credits the user and stops the robot; later correct
    guesses only get the model's reply.

    Args:
        guess_input: The guess input
        user_id: Id of the user making the guess
    Returns:
        A response to the user.
    """
    if not self.current_item:
        # No item has been chosen yet, so there is nothing to compare against.
        return "There is nothing to guess yet."
    messages = [
        SystemMessage(content=guess_agent_prompt),
        SystemMessage(content="The item drawn is: " + self.current_item),
        HumanMessage(content="I believe the item is: " + guess_input),
    ]
    self.agent_history.append(gr.ChatMessage(role="user", content=("(guess by " + self.get_user_display_name(user_id) + ")\n" + guess_input)))
    guess_chat = ChatOpenAI(model="gpt-4o-mini")
    response = await guess_chat.ainvoke(messages)
    guessed, chat_response = self._parse_guess_response(response.content)
    self.agent_history.append(gr.ChatMessage(role="assistant", content="(guess) " + chat_response))
    # Only the first correct guess counts; without this check the same answer
    # could be submitted again and again to collect points.
    if guessed and not self.is_drawing_guessed:
        self.is_drawing_guessed = True
        self.guesser_id = user_id
        # The user may have been dropped from connected_users while the model
        # was answering; the guess still counts but there is no score to update.
        user = self.connected_users.get(user_id)
        if user is not None:
            user.score += 1
            self._update_leaderboard(user)
        self.reachy_api.stop_drawing()
    return chat_response

@staticmethod
def _parse_guess_response(raw_content) -> tuple:
    """Extract (guessed, response) from the guess model's reply without raising."""
    text = raw_content if isinstance(raw_content, str) else str(raw_content)
    # Tolerate text around the JSON object (e.g. Markdown code fences).
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        try:
            payload = json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            payload = None
        if isinstance(payload, dict) and "response" in payload:
            # Only a JSON `true` validates the guess (a string such as "false"
            # would otherwise be truthy).
            return payload.get("guessed") is True, str(payload["response"])
    # Unparseable reply: treat as not guessed and show the raw text.
    return False, text.strip()
def _update_leaderboard(self, user: User):
if user not in self.leaderboard:
self.leaderboard.append(user)
self.leaderboard = sorted(self.leaderboard, key=lambda user: user.score, reverse=True)
self.leaderboard = self.leaderboard[:10]
def get_user_display_name(self, user_id: str) -> str:
if user_id in self.connected_users and self.connected_users[user_id].name:
return self.connected_users[user_id].name
else:
user_in_leaderboard = next((user for user in self.leaderboard if user.id == user_id), None)
if user_in_leaderboard and user_in_leaderboard.name:
return user_in_leaderboard.name
else:
return user_id
async def _check_users(self):
while True:
users_to_remove = []
for user in self.connected_users.values():
if user.last_ping < datetime.datetime.now() - datetime.timedelta(seconds=5):
users_to_remove.append(user.id)
logging.info(f"User {user.id} disconnected")
for user_id in users_to_remove:
self.connected_users.pop(user_id)
await asyncio.sleep(1)
def on_user_ping(self, user_id: str):
if user_id not in self.connected_users:
self.connected_users[user_id] = User(id=user_id)
else:
self.connected_users[user_id].last_ping = datetime.datetime.now()
def set_user_name(self, user_id: str, name: str):
if user_id not in self.connected_users:
self.connected_users[user_id] = User(id=user_id)
self.connected_users[user_id].name = name
async def _reset_game_after_timeout(self):
    """Background loop: reset a finished game once it has been left for GAME_TIMEOUT seconds.

    Checked once per second. While a game is running the timer keeps being
    restarted; afterwards, if the status is not idle and the timeout has
    elapsed, reset_game() is called.
    """
    timeout = datetime.timedelta(seconds=GAME_TIMEOUT)
    # Initialised up front so the comparison below can never hit an unbound
    # variable when the loop first sees a non-idle status with no game running.
    last_time_game_running = datetime.datetime.now()
    while True:
        now = datetime.datetime.now()
        if self.game_running:
            last_time_game_running = now
        elif self.agent_status != AgentStatus.IDLE and now - last_time_game_running > timeout:
            self.reset_game()
        await asyncio.sleep(1)
def reset_game(self):
    """Return to the idle state and forget the previous round.

    Clears the canvas, the generated image, the reveal, the guess result and
    the chat history, and deletes the agent's conversation thread. The current
    item, the list of previous items, the users and the leaderboard are kept.
    """
    self.agent_status = AgentStatus.IDLE
    self.current_draw_input = None
    self.drawing_image = None
    self.image_draw = None
    self.generated_source_image = None
    self.generated_source_image_url = None
    self.displayed_item_to_guess = None
    self.is_drawing_guessed = False
    self.guesser_id = None
    self.agent_history = []
    # The checkpointer only exists after async_init(); "1" is the thread_id
    # configured there for the agent's conversation.
    if self.checkpointer is not None:
        self.checkpointer.delete_thread("1")