|
import os |
|
# Fail fast at import time. The ChatOpenAI clients created in this module are
# built without an explicit api_key, so they presumably rely on this variable.
if not os.environ.get("OPENAI_API_KEY"):
    raise ValueError("OPENAI_API_KEY must be set")
|
|
|
import random |
|
from typing import Any |
|
import datetime |
|
|
|
import asyncio |
|
from io import BytesIO |
|
from enum import Enum |
|
import json |
|
import requests |
|
import gradio as gr |
|
from PIL import Image, ImageDraw |
|
from langchain_core.messages import HumanMessage, ToolMessage, AIMessage, SystemMessage |
|
from langchain_core.tools import StructuredTool |
|
from langchain_mcp_adapters.client import MultiServerMCPClient |
|
from langchain_openai import ChatOpenAI |
|
from langgraph.prebuilt import create_react_agent |
|
from langgraph.checkpoint.memory import InMemorySaver |
|
from draw_input_schema import DrawInputSchema, Line, Point |
|
from reachy_api import ReachyAPI |
|
|
|
import logging |
|
# Configure the root logger once, at import time, so INFO messages such as the
# user-disconnect notices are emitted.
logging.basicConfig(level=logging.INFO)

# Absolute directory of this source file; resources are resolved relative to it
# so they are found regardless of the current working directory.
script_dir = os.path.split(os.path.abspath(__file__))[0]

# Folder with the robot's emotion recordings (handed to ReachyAPI).
EMOTION_RECORDINGS_FOLDER = os.path.join(script_dir, "emotion_recordings")

# Seconds a finished-but-not-idle game is kept before it is reset automatically.
GAME_TIMEOUT = 30
|
|
|
# System prompt for the ReAct game agent (passed as `prompt=` to
# create_react_agent in PictionagentManager.async_init). It tells the agent how
# to pick an item for the given topic and how to wrap up a game.
game_agent_prompt="""

You are a drawing agent and we are playing a game. I will give you a topic, and you will choose randomly an item to draw that will correspond to the given topic (for example, if the topic is "Animal", the item could be "Dog", "Cat", "Bird", etc.). Do not always choose first the obvious like Elephant or Eiffel Tower and do not choose the same item twice.

The user can only guess during the drawing, after the drawing is completed the user can no longer guess, so do not expect them to, just display the drawing image using provided tool and share the item drawn.

If it was guessed, share what it was and the id of the user who guessed it (if it was guessed).

"""



# System prompt for the one-shot guess judge in
# PictionagentManager.try_guess_drawing. The model's reply is decoded as JSON
# and its "guessed" and "response" keys are read, so the structure requested
# below is load-bearing: keep it in sync with that parser.
guess_agent_prompt="""

We are playing a draw and guess game. You have drawn an item and the user tries to guess it.

You must return a message indicating if the user has guessed the item correctly or not, but never tell the item to guess to the user if they have not guessed it correctly.

If the user is very close (for example with some small typos), you can validate the guess.



Answer using json format following this structure:

{

"guessed": true/false,

"response": "response to the user"

}

"""
|
|
|
# Topics the game agent can be asked to draw from. Each value is the
# human-readable label sent to the agent when a game starts.
ItemTopic = Enum(
    "ItemTopic",
    [
        ("ANIMAL", "Animal"),
        ("SHAPE", "Shape"),
        ("FAMOUS_MONUMENT", "Famous monument"),
    ],
)
|
|
|
# Lifecycle states of a game. The values double as the strings the LLM passes
# to the set_status tool (looked up with AgentStatus(status)), so they must
# stay stable.
AgentStatus = Enum(
    "AgentStatus",
    [
        ("IDLE", "Idle"),
        ("CHOOSING_ITEM", "Choosing item"),
        ("GENERATING_IMAGE_AND_POINTS", "Generating image and points"),
        ("DRAWING", "Drawing"),
        ("WAITING_FOR_GUESS", "Waiting for guess"),
        ("DRAWING_COMPLETED", "Drawing completed"),
        ("DRAWING_GUESSED", "Drawing guessed"),
    ],
)
|
class User:
    """A player identified by a string id, with an optional display name and a score."""

    id: str
    # None until a name is assigned (e.g. through set_user_name).
    name: "str | None"
    # Time of the most recent ping received from this user.
    last_ping: datetime.datetime
    # Number of drawings this user guessed correctly.
    score: int

    def __init__(self, id: str):
        self.id = id
        self.last_ping = datetime.datetime.now()
        self.name = None
        self.score = 0

    def __repr__(self) -> str:
        return f"{type(self).__name__}(id={self.id!r}, name={self.name!r}, score={self.score})"
|
|
|
class PictionagentManager:
    """Run "draw and guess" games between a LangGraph agent, a Reachy robot and users.

    The game agent (built in ``async_init``) picks an item for ``current_topic``,
    obtains a source image and drawing points from a remote MCP tool, and has the
    robot draw them through ``ReachyAPI`` while ``update_drawing`` mirrors the
    progress on ``drawing_image``. Users submit guesses through
    ``try_guess_drawing``; connected users and the leaderboard are tracked here too.
    """

    fake_robot: bool
    async_init_done: bool
    nb_drawing_points: int
    drawing_duration: float
    current_draw_input: DrawInputSchema
    drawing_image: Image.Image
    image_draw: ImageDraw.ImageDraw
    is_drawing: bool
    drawn_points: list[Point]
    generated_source_image: Image.Image
    generated_source_image_url: str
    displayed_item_to_guess: str
    game_chat: ChatOpenAI
    game_agent: Any
    checkpointer: InMemorySaver
    config: dict
    current_topic: ItemTopic
    current_item: str
    previous_items: list[str]
    agent_status: AgentStatus
    is_drawing_guessed: bool
    guesser_id: str
    agent_history: list[gr.ChatMessage]
    game_running: bool
    connected_users: dict[str, User]
    leaderboard: list[User]
    reachy_api: ReachyAPI
    # Strong references to fire-and-forget tasks (asyncio itself only keeps weak ones).
    _background_tasks: set[asyncio.Task]

    EXPECTED_TOOL_NAME_GENERATE_IMAGE_AND_POINTS = "generate_image_and_points"
    # LangGraph thread id, used both in ``config`` and when ``reset_game`` clears it.
    THREAD_ID = "1"
    # Canvas width/height in pixels; also passed as max_x/max_y to the robot.
    CANVAS_SIZE = 512
    # Seconds without a ping after which a user is dropped from ``connected_users``.
    USER_PING_TIMEOUT_SECONDS = 5
    # Seconds allowed for downloading the generated source image.
    IMAGE_DOWNLOAD_TIMEOUT_SECONDS = 30
    # Number of users kept on the leaderboard.
    LEADERBOARD_SIZE = 10

    def __init__(self, fake_robot: bool):
        """Set up an idle manager; call ``async_init`` before starting a game.

        Args:
            fake_robot: Forwarded to ``ReachyAPI`` when the robot client is created.
        """
        self.fake_robot = fake_robot
        self.async_init_done = False
        self.nb_drawing_points = 100
        # Passed to ReachyAPI.draw as duration_between_points (presumably seconds).
        self.drawing_duration = 0.05
        self.current_draw_input = None
        self.drawing_image = None
        self.image_draw = None
        self.is_drawing = False
        self.drawn_points = []
        self.generated_source_image = None
        self.generated_source_image_url = None
        self.displayed_item_to_guess = None
        self.game_chat = None
        self.game_agent = None
        self.checkpointer = None
        self.config = None
        self.current_topic = ItemTopic.ANIMAL
        self.current_item = None
        self.previous_items = []
        self.agent_status = AgentStatus.IDLE
        self.is_drawing_guessed = False
        self.guesser_id = None
        self.agent_history = []
        self.game_running = False
        self.connected_users = {}
        self.leaderboard = []
        self.reachy_api = None
        self._background_tasks = set()

    async def async_init(self):
        """Build the game agent and its tools, connect the robot and start background loops.

        Does nothing once a previous call has succeeded. Errors are logged and re-raised.
        """
        if self.async_init_done:
            return

        try:
            self.game_chat = ChatOpenAI(model="gpt-4o-mini")
            mcp_client = MultiServerMCPClient(
                {
                    "gradio-draw-generation": {
                        "url": "https://agents-mcp-hackathon-concept-to-drawing-points.hf.space/gradio_api/mcp/sse",
                        "transport": "sse",
                    }
                }
            )
            tools = await mcp_client.get_tools()
            # The local tools' docstrings are handed to the LLM as tool descriptions.
            tools.append(StructuredTool.from_function(coroutine=self.set_status,
                                                      name="set_status",
                                                      description=self.set_status.__doc__))
            tools.append(StructuredTool.from_function(func=self.save_item,
                                                      name="save_item",
                                                      description=self.save_item.__doc__))
            tools.append(StructuredTool.from_function(coroutine=self.draw,
                                                      name="draw",
                                                      description=self.draw.__doc__))
            tools.append(StructuredTool.from_function(func=self.display_drawing_image_and_item,
                                                      name="display_drawing_image_and_item",
                                                      description=self.display_drawing_image_and_item.__doc__))

            self.checkpointer = InMemorySaver()
            self.config = {"configurable": {"thread_id": self.THREAD_ID}, "recursion_limit": 50}
            self.game_agent = create_react_agent(
                model=self.game_chat.bind_tools(tools, parallel_tool_calls=False),
                tools=tools,
                checkpointer=self.checkpointer,
                prompt=game_agent_prompt,
            )

            self.reachy_api = ReachyAPI(fake_robot=self.fake_robot,
                                        emotion_recordings_folder=EMOTION_RECORDINGS_FOLDER,
                                        robot_ip="localhost")
            await self.reachy_api.connect()

            # Start the endless loops only once everything else succeeded, so a failed
            # init followed by a retry does not leave duplicate loops running.
            self._spawn_background_task(self._check_users())
            self._spawn_background_task(self._reset_game_after_timeout())

            self.async_init_done = True
        except Exception as e:
            logging.exception("Error initializing agent: %s", e)
            raise

    def _spawn_background_task(self, coroutine) -> asyncio.Task:
        """Schedule *coroutine* as a task, keeping a strong reference until it finishes."""
        task = asyncio.create_task(coroutine)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Release a finished background task and log its exception, if any."""
        self._background_tasks.discard(task)
        if not task.cancelled() and task.exception() is not None:
            logging.error("Background task failed", exc_info=task.exception())

    def _play_random_emotion(self, animations: list[str]) -> None:
        """Start one of *animations*, chosen at random, on the robot without waiting for it."""
        self._spawn_background_task(self.reachy_api.play_emotion(random.choice(animations)))

    async def set_status(self, status: str) -> str:
        """Set the status agent

        Args:
            status: The status to set, between "Choosing item", "Generating image and points", "Drawing", "Drawing completed" or "Drawing guessed"

        Returns:
            A message indicating that the status has been set.
        """
        # The docstring above is the LLM-facing tool description; keep its wording.
        # AgentStatus(status) raises ValueError for an unknown status string.
        self.agent_status = AgentStatus(status)

        if self.agent_status == AgentStatus.CHOOSING_ITEM:
            self._play_random_emotion(["thoughtful1", "thoughtful2"])
        elif self.agent_status == AgentStatus.DRAWING_GUESSED or (
            self.agent_status == AgentStatus.DRAWING_COMPLETED and self.is_drawing_guessed
        ):
            self._play_random_emotion(["success2", "enthusistic1"])
        elif self.agent_status == AgentStatus.DRAWING_COMPLETED:
            # Completed without anybody guessing it.
            self._play_random_emotion(["disgusted1", "anxiety1"])

        self.agent_history.append(gr.ChatMessage(role="assistant", content="(" + status + ")"))
        return f"Status set to {status}"

    def save_item(self, item: str) -> str:
        """Save the chosen item to draw

        Args:
            item: The item to save

        Returns:
            A message indicating that the item has been saved.
        """
        # The docstring above is the LLM-facing tool description; keep its wording.
        self.current_item = item
        # Fed back to the agent in start_game so it avoids repeating items.
        self.previous_items.append(item)
        return "Item saved"

    async def draw(self) -> str:
        """Draw points on the canvas, from previously generated points.

        Returns:
            A message indicating if the drawing has been completed or aborted.
        """
        # The docstring above is the LLM-facing tool description; keep its wording.
        self.drawing_image = Image.new("RGB", (self.CANVAS_SIZE, self.CANVAS_SIZE), "white")
        self.image_draw = ImageDraw.Draw(self.drawing_image)
        self.drawn_points = []
        self.is_drawing = True

        # Mirror the robot's progress onto drawing_image while it draws.
        drawing_update_task = asyncio.create_task(self.update_drawing())
        try:
            result = await self.reachy_api.draw(max_x=self.CANVAS_SIZE, max_y=self.CANVAS_SIZE,
                                                draw_input=self.current_draw_input,
                                                duration_between_points=self.drawing_duration)

            if not self.is_drawing_guessed:
                # Give users a few extra seconds to guess the finished drawing.
                await self.set_status(AgentStatus.WAITING_FOR_GUESS.value)
                await asyncio.sleep(5)
        finally:
            # Always stop the mirroring loop, even if the robot call failed;
            # otherwise update_drawing would keep polling forever.
            self.is_drawing = False
            await drawing_update_task

        if result:
            return "Drawing completed"
        if self.guesser_id is None:
            # Stopped without a recorded guesser: nobody to credit.
            return "Drawing aborted"
        return "Drawing stopped as user '" + self.get_user_display_name(self.guesser_id) + "' guessed the drawn item"

    def display_drawing_image_and_item(self) -> str:
        """Display the initial image of the drawing and the item drawn at the end of the game. Must be called at the end of the game.

        Returns:
            A message indicating that the drawing image and item have been displayed.
        """
        # The docstring above is the LLM-facing tool description; keep its wording.
        response = requests.get(self.generated_source_image_url,
                                timeout=self.IMAGE_DOWNLOAD_TIMEOUT_SECONDS)
        # Surface HTTP failures directly instead of letting PIL fail on an error page.
        response.raise_for_status()
        self.generated_source_image = Image.open(BytesIO(response.content))
        self.displayed_item_to_guess = self.current_item
        return "Drawing image and item displayed"

    async def start_game(self) -> None:
        """Reset the game state and run the game agent for ``current_topic`` until it finishes.

        Agent messages are mirrored into ``agent_history``. Agent errors are logged,
        reported in the history and put the manager back to ``AgentStatus.IDLE``.
        """
        self.reset_game()
        self.game_running = True

        messages = [
            SystemMessage(content="Start the game for this topic: " + self.current_topic.value),
            SystemMessage(content="Previous items: " + str(self.previous_items)),
            SystemMessage(content="Number of points to draw: " + str(self.nb_drawing_points)),
        ]
        self.agent_history.append(
            gr.ChatMessage(role="user", content="Start the game for this topic: " + self.current_topic.value)
        )

        try:
            async for event in self.game_agent.astream({"messages": messages}, self.config, stream_mode="values"):
                response = event["messages"][-1]
                response.pretty_print()
                if isinstance(response, ToolMessage):
                    self.handle_tool_message(response)
                elif isinstance(response, AIMessage) and response.content:
                    # AI messages that only carry tool calls typically have empty
                    # content; don't show them as empty chat bubbles.
                    self.agent_history.append(gr.ChatMessage(role="assistant", content=response.content))
        except Exception as e:
            logging.exception("Error in game_agent.astream: %s", e)
            self.agent_history.append(gr.ChatMessage(role="assistant", content="Error in agent execution"))
            self.agent_status = AgentStatus.IDLE
        finally:
            # Also cleared on cancellation, so the timeout loop can reset the game.
            self.game_running = False

    def handle_tool_message(self, response: ToolMessage) -> None:
        """Capture the output of the remote image/points generation tool.

        Other tool messages are ignored. The expected content is a two-element
        list (possibly JSON-encoded): a text containing "Image URL: <url>" and a
        JSON document validated as ``DrawInputSchema``.
        """
        if not response.name or self.EXPECTED_TOOL_NAME_GENERATE_IMAGE_AND_POINTS not in response.name:
            return

        content = response.content
        if isinstance(content, str):
            content = json.loads(content)
        response_image, response_points = content[0], content[1]
        self.generated_source_image_url = response_image.split("Image URL: ")[1]
        self.current_draw_input = DrawInputSchema.model_validate_json(response_points)

    async def update_drawing(self):
        """Every 0.1 s, render newly drawn robot positions onto ``drawing_image`` while ``is_drawing``."""
        while self.is_drawing:
            drawn_positions = self.reachy_api.get_drawn_positions()
            if len(drawn_positions) > len(self.drawn_points):
                new_points = drawn_positions[len(self.drawn_points):]
                # Continue from the last rendered point; the very first segment starts at itself.
                last_point = self.drawn_points[-1] if self.drawn_points else new_points[0]
                for point in new_points:
                    self.image_draw.line([(last_point.x, last_point.y), (point.x, point.y)], fill="black", width=2)
                    last_point = point
                self.drawn_points = drawn_positions
            await asyncio.sleep(0.1)

    async def try_guess_drawing(self, guess_input: str, user_id: str) -> str:
        """Try to guess the drawing

        Asks an LLM judge whether *guess_input* matches ``current_item``. On a
        correct guess the user's score is incremented, the leaderboard updated
        and the robot told to stop drawing.

        Args:
            guess_input: The guess input
            user_id: Id of the user making the guess.

        Returns:
            A response to the user.
        """
        messages = [
            SystemMessage(content=guess_agent_prompt),
            SystemMessage(content="The item drawn is: " + self.current_item),
            HumanMessage(content="I believe the item is: " + guess_input),
        ]
        self.agent_history.append(gr.ChatMessage(
            role="user",
            content=("(guess by " + self.get_user_display_name(user_id) + ")\n" + guess_input),
        ))

        guess_chat = ChatOpenAI(model="gpt-4o-mini")
        response = await guess_chat.ainvoke(messages)
        json_response = self._parse_guess_response(response.content)

        guessed = json_response["guessed"]
        chat_response = json_response["response"]
        self.agent_history.append(gr.ChatMessage(role="assistant", content="(guess) " + chat_response))

        if guessed:
            # The user may have been pruned by _check_users (or never pinged);
            # re-register them instead of failing with a KeyError.
            user = self._get_or_create_user(user_id)
            user.score += 1
            self._update_leaderboard(user)
            self.is_drawing_guessed = True
            self.guesser_id = user_id
            self.reachy_api.stop_drawing()
        return chat_response

    @staticmethod
    def _parse_guess_response(content: str) -> dict:
        """Decode the guess judge's JSON reply, tolerating a surrounding Markdown code fence."""
        text = content.strip()
        if text.startswith("```"):
            # Drop the opening fence line (e.g. "```json") and the closing fence.
            text = text.split("\n", 1)[1] if "\n" in text else ""
            text = text.rsplit("```", 1)[0]
        return json.loads(text)

    def _update_leaderboard(self, user: User):
        """Insert *user* if absent, sort by descending score and keep the top entries."""
        if user not in self.leaderboard:
            self.leaderboard.append(user)
        self.leaderboard = sorted(self.leaderboard, key=lambda entry: entry.score, reverse=True)[:self.LEADERBOARD_SIZE]

    def get_user_display_name(self, user_id: str) -> str:
        """Return the user's name (connected users first, then leaderboard), else *user_id*."""
        connected = self.connected_users.get(user_id)
        if connected is not None and connected.name:
            return connected.name
        ranked = self._find_leaderboard_user(user_id)
        if ranked is not None and ranked.name:
            return ranked.name
        return user_id

    def _find_leaderboard_user(self, user_id: str):
        """Return the leaderboard entry for *user_id*, or None."""
        return next((entry for entry in self.leaderboard if entry.id == user_id), None)

    def _get_or_create_user(self, user_id: str) -> User:
        """Return the connected ``User`` for *user_id*, registering it if needed.

        A returning user still on the leaderboard gets that same entry back (name
        and score kept), so the leaderboard never holds two entries for one id.
        """
        user = self.connected_users.get(user_id)
        if user is None:
            user = self._find_leaderboard_user(user_id)
            if user is None:
                user = User(id=user_id)
            # A restored entry carries a stale ping time; refresh it so
            # _check_users does not drop the user again immediately.
            user.last_ping = datetime.datetime.now()
            self.connected_users[user_id] = user
        return user

    async def _check_users(self):
        """Every second, drop users whose last ping is older than the ping timeout."""
        timeout = datetime.timedelta(seconds=self.USER_PING_TIMEOUT_SECONDS)
        while True:
            deadline = datetime.datetime.now() - timeout
            # Work on a snapshot so the dict is never iterated while it is mutated.
            stale_ids = [uid for uid, user in list(self.connected_users.items()) if user.last_ping < deadline]
            for user_id in stale_ids:
                logging.info("User %s disconnected", user_id)
                self.connected_users.pop(user_id, None)
            await asyncio.sleep(1)

    def on_user_ping(self, user_id: str):
        """Record a heartbeat from *user_id*, registering the user if unknown."""
        self._get_or_create_user(user_id).last_ping = datetime.datetime.now()

    def set_user_name(self, user_id: str, name: str):
        """Set the display name of *user_id*, registering the user if unknown."""
        self._get_or_create_user(user_id).name = name

    async def _reset_game_after_timeout(self):
        """Reset a game left non-idle for GAME_TIMEOUT seconds after its agent run ended."""
        timeout = datetime.timedelta(seconds=GAME_TIMEOUT)
        # Initialised up front: otherwise the first check could hit an unbound
        # variable (and kill this loop) if no running game has been observed yet.
        last_time_game_running = datetime.datetime.now()
        while True:
            if self.game_running:
                last_time_game_running = datetime.datetime.now()
            elif self.agent_status != AgentStatus.IDLE and datetime.datetime.now() - last_time_game_running > timeout:
                self.reset_game()
            await asyncio.sleep(1)

    def reset_game(self):
        """Clear per-game state, the chat history and the agent's conversation thread."""
        self.agent_status = AgentStatus.IDLE
        self.current_draw_input = None
        self.drawing_image = None
        self.image_draw = None
        self.generated_source_image = None
        self.generated_source_image_url = None
        self.displayed_item_to_guess = None
        self.is_drawing_guessed = False
        self.guesser_id = None
        self.agent_history = []
        # The checkpointer only exists once async_init has run.
        if self.checkpointer is not None:
            self.checkpointer.delete_thread(self.THREAD_ID)
|
|