jbilcke-hf's picture
Upload 5 files
4315c88 verified
raw
history blame
21.7 kB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MatrixGame Websocket Gaming Server
This script implements a websocket server for the MatrixGame project,
allowing real-time streaming of game frames based on player inputs.
"""
import argparse
import asyncio
import base64
import io
import json
import logging
import math
import os
import pathlib
import time
import uuid
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
import torch
from aiohttp import WSMsgType, web
from PIL import Image

from condtions import Bench_actions_76
# Logging setup: INFO and above, each record prefixed with timestamp,
# logger name and level. Module code logs through `logger`.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class FrameGenerator:
    """
    Simplified frame generator for the game.
    In production, this would use the MatrixGame model.

    Scene frames are stored as uint8 arrays of shape
    (frame_height, frame_width, 3) in OpenCV's BGR channel order, so they can
    be drawn on and JPEG-encoded with cv2 directly.
    """

    # Scene names, in the order they appear in ``self.scenes``.
    SCENE_NAMES = ('forest', 'desert', 'beach', 'hills', 'river', 'icy', 'mushroom', 'plain')
    # File extensions (compared case-insensitively) loaded from a scene directory.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # Overlay labels for the keyboard-state slots, in state order.
    KEY_LABELS = ("W", "S", "A", "D", "JUMP", "ATTACK")

    def __init__(self):
        self.frame_width = 640
        self.frame_height = 360
        self.fps = 16
        # Advanced on every get_next_frame() call, whatever the scene.
        self.frame_count = 0
        self.scenes = {name: self._load_scene_frames(name) for name in self.SCENE_NAMES}

    def _blank_frame(self):
        """Return a uniform grey (value 100) frame of the configured size."""
        return np.full((self.frame_height, self.frame_width, 3), 100, dtype=np.uint8)

    def _load_scene_frames(self, scene_name):
        """
        Load initial frames for a scene from its asset directory.

        Images are converted to RGB, resized to the frame size and reordered
        to BGR. If nothing can be loaded, a single grey placeholder frame
        labelled with the scene name is returned, so the list is never empty.
        """
        frames = []
        scene_dir = f"./GameWorldScore/asset/init_image/{scene_name}"
        # isdir (not exists): os.listdir on a regular file would raise.
        if os.path.isdir(scene_dir):
            image_files = sorted(
                f for f in os.listdir(scene_dir)
                if f.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            for img_file in image_files:
                try:
                    img_path = os.path.join(scene_dir, img_file)
                    with Image.open(img_path) as img:
                        rgb = img.convert("RGB").resize((self.frame_width, self.frame_height))
                    # PIL yields RGB but cv2.imencode interprets arrays as BGR;
                    # without this the red and blue channels come out swapped.
                    frames.append(cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2BGR))
                except Exception as e:
                    logger.error(f"Error loading image {img_file}: {str(e)}")
        # If no frames were loaded, create a default colored frame with text
        if not frames:
            frame = self._blank_frame()
            cv2.putText(frame, f"Scene: {scene_name}", (50, 180),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            frames.append(frame)
        return frames

    def get_next_frame(self, scene_name, keyboard_condition=None, mouse_condition=None):
        """
        Generate the next frame based on current conditions.

        Args:
            scene_name: Name of the current scene; unknown names fall back
                to 'forest'.
            keyboard_condition: Optional sequence whose first element is the
                keyboard state, one truthy/falsy value per KEY_LABELS slot
                (extra slots are ignored).
            mouse_condition: Optional sequence whose first element is an
                (x, y) pair; non-finite values are drawn as 0.

        Returns:
            JPEG bytes of the frame.
        """
        scene_frames = self.scenes.get(scene_name, self.scenes['forest'])
        # Demo behaviour: cycle through the pre-loaded frames instead of running
        # the model. frame_count is shared by every caller of this instance.
        frame = scene_frames[self.frame_count % len(scene_frames)].copy()
        self.frame_count += 1

        if keyboard_condition:
            # zip() stops at the shorter sequence, so a longer state list
            # cannot raise IndexError on the label lookup.
            for i, (label, key_pressed) in enumerate(zip(self.KEY_LABELS, keyboard_condition[0])):
                color = (0, 255, 0) if key_pressed else (100, 100, 100)
                cv2.putText(frame, label, (20 + i * 100, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        if mouse_condition:
            mouse_x, mouse_y = mouse_condition[0]
            # int(nan) / int(inf) would raise and kill the caller's stream.
            if not (math.isfinite(mouse_x) and math.isfinite(mouse_y)):
                mouse_x = mouse_y = 0.0
            # Scale for visualisation, clamped to one frame size so extreme
            # inputs cannot overflow cv2's integer point coordinates.
            offset_x = int(max(-self.frame_width, min(self.frame_width, mouse_x * 100)))
            offset_y = int(max(-self.frame_height, min(self.frame_height, mouse_y * 100)))
            center_x, center_y = self.frame_width // 2, self.frame_height // 2
            cv2.circle(frame, (center_x + offset_x, center_y - offset_y), 10, (255, 0, 0), -1)
            cv2.putText(frame, f"Mouse: {mouse_x:.2f}, {mouse_y:.2f}",
                        (self.frame_width - 250, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)

        success, buffer = cv2.imencode('.jpg', frame)
        if not success:
            logger.error("Failed to encode frame as JPEG")
            # Fall back to a blank frame
            success, buffer = cv2.imencode('.jpg', self._blank_frame())
        return buffer.tobytes()
class GameSession:
    """
    Represents a user's gaming session.

    Each WebSocket connection gets its own session. Client actions are put on
    ``action_queue`` and handled one at a time by a background task. While
    streaming is active, a second task pushes base64-encoded JPEG frames to
    the client as ``{'action': 'frame', ...}`` messages.
    """

    # Upper bound on the client-requested frame rate. Frames are generated
    # synchronously inside the streaming coroutine, so an unbounded rate would
    # let one client monopolise the event loop.
    MAX_FPS = 60

    # Client key names (matched case-insensitively) -> index into keyboard_state.
    KEY_MAP = {
        'w': 0, 'forward': 0,
        's': 1, 'back': 1, 'backward': 1,
        'a': 2, 'left': 2,
        'd': 3, 'right': 3,
        'space': 4, 'jump': 4,
        'shift': 5, 'attack': 5, 'ctrl': 5
    }

    def __init__(self, user_id: str, ws: "web.WebSocketResponse", game_manager):
        self.user_id = user_id
        self.ws = ws
        self.game_manager = game_manager
        # Incoming client actions, consumed by _process_action_queue
        self.action_queue = asyncio.Queue()
        self.created_at = time.time()
        self.last_activity = time.time()
        # Game state
        self.current_scene = "forest"  # Default scene
        self.is_streaming = False
        self.stream_task = None
        # Current input state
        self.keyboard_state = [0, 0, 0, 0, 0, 0]  # forward, back, left, right, jump, attack
        self.mouse_state = [0, 0]  # x, y, each clamped to [-1, 1] by _handle_mouse_input
        self.background_tasks = []

    async def start(self):
        """Start all the queue processors for this session."""
        self.background_tasks = [
            asyncio.create_task(self._process_action_queue()),
        ]
        logger.info(f"Started game session for user {self.user_id}")

    async def stop(self):
        """Stop streaming (if any) and cancel all background tasks of this session."""
        self.is_streaming = False
        stream_task, self.stream_task = self.stream_task, None
        # Cancel whenever the task is still pending, not only when the flag is
        # set: the flag and the task can disagree during a stop_stream.
        if stream_task is not None and not stream_task.done():
            stream_task.cancel()
            try:
                await stream_task
            except asyncio.CancelledError:
                pass
        for task in self.background_tasks:
            task.cancel()
        try:
            # Wait for tasks to complete cancellation
            await asyncio.gather(*self.background_tasks, return_exceptions=True)
        except asyncio.CancelledError:
            pass
        logger.info(f"Stopped game session for user {self.user_id}")

    async def _process_action_queue(self):
        """Process queued client actions one by one and reply on the WebSocket."""
        while True:
            data = await self.action_queue.get()
            try:
                action_type = data.get('action')
                if action_type == 'start_stream':
                    result = await self._handle_start_stream(data)
                elif action_type == 'stop_stream':
                    result = await self._handle_stop_stream(data)
                elif action_type == 'keyboard_input':
                    result = await self._handle_keyboard_input(data)
                elif action_type == 'mouse_input':
                    result = await self._handle_mouse_input(data)
                elif action_type == 'change_scene':
                    result = await self._handle_scene_change(data)
                else:
                    result = {
                        'action': action_type,
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Unknown action: {action_type}'
                    }
                await self.ws.send_json(result)
                self.last_activity = time.time()
            except Exception as e:
                logger.exception(f"Error processing action for user {self.user_id}: {str(e)}")
                try:
                    await self.ws.send_json({
                        'action': data.get('action'),
                        'requestId': data.get('requestId', 'unknown'),
                        'success': False,
                        'error': f'Error processing action: {str(e)}'
                    })
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")
            finally:
                self.action_queue.task_done()

    async def _handle_start_stream(self, data: Dict) -> Dict:
        """
        Handle a request to start streaming frames.

        ``fps`` (default 16) must be a finite positive number; it is capped
        at MAX_FPS. Invalid values are rejected before any task is created,
        so a bad request cannot leave the session stuck in streaming state.
        """
        request_id = data.get('requestId')
        if self.is_streaming:
            return {
                'action': 'start_stream',
                'requestId': request_id,
                'success': False,
                'error': 'Stream already active'
            }
        fps = data.get('fps', 16)
        # bool is an int subclass; True/False are not meaningful frame rates.
        if (isinstance(fps, bool) or not isinstance(fps, (int, float))
                or not math.isfinite(fps) or fps <= 0):
            return {
                'action': 'start_stream',
                'requestId': request_id,
                'success': False,
                'error': f'Invalid fps: {fps!r}'
            }
        fps = min(fps, self.MAX_FPS)
        self.is_streaming = True
        self.stream_task = asyncio.create_task(self._stream_frames(fps))
        return {
            'action': 'start_stream',
            'requestId': request_id,
            'success': True,
            'message': f'Streaming started at {fps} FPS'
        }

    async def _handle_stop_stream(self, data: Dict) -> Dict:
        """Handle a request to stop streaming; waits for the stream task to finish."""
        if not self.is_streaming:
            return {
                'action': 'stop_stream',
                'requestId': data.get('requestId'),
                'success': False,
                'error': 'No active stream to stop'
            }
        self.is_streaming = False
        if self.stream_task:
            self.stream_task.cancel()
            try:
                await self.stream_task
            except asyncio.CancelledError:
                pass
            self.stream_task = None
        return {
            'action': 'stop_stream',
            'requestId': data.get('requestId'),
            'success': True,
            'message': 'Streaming stopped'
        }

    async def _handle_keyboard_input(self, data: Dict) -> Dict:
        """Update keyboard_state from a key press/release; unknown or non-string keys are ignored."""
        key = data.get('key', '')
        pressed = data.get('pressed', False)
        key_idx = self.KEY_MAP.get(key.lower()) if isinstance(key, str) else None
        if key_idx is not None:
            self.keyboard_state[key_idx] = 1 if pressed else 0
        return {
            'action': 'keyboard_input',
            'requestId': data.get('requestId'),
            'success': True,
            'keyboardState': self.keyboard_state
        }

    async def _handle_mouse_input(self, data: Dict) -> Dict:
        """
        Update mouse_state from client ``x``/``y`` values.

        Values must be convertible to finite floats and are clamped to
        [-1, 1]; otherwise an error is returned and the state is unchanged.
        """
        def _error(message: str) -> Dict:
            return {
                'action': 'mouse_input',
                'requestId': data.get('requestId'),
                'success': False,
                'error': message
            }

        try:
            mouse_x = float(data.get('x', 0))
            mouse_y = float(data.get('y', 0))
        except (TypeError, ValueError):
            return _error('Invalid mouse coordinates')
        # NaN/inf would later crash frame rendering (int(nan) raises).
        if not (math.isfinite(mouse_x) and math.isfinite(mouse_y)):
            return _error('Mouse coordinates must be finite numbers')
        self.mouse_state = [max(-1.0, min(1.0, mouse_x)), max(-1.0, min(1.0, mouse_y))]
        return {
            'action': 'mouse_input',
            'requestId': data.get('requestId'),
            'success': True,
            'mouseState': self.mouse_state
        }

    async def _handle_scene_change(self, data: Dict) -> Dict:
        """Switch current_scene if the requested scene is one the game manager knows."""
        scene_name = data.get('scene', 'forest')
        # Single source of truth: the scenes the frame generator actually loaded.
        valid_scenes = self.game_manager.valid_scenes
        if scene_name not in valid_scenes:
            return {
                'action': 'change_scene',
                'requestId': data.get('requestId'),
                'success': False,
                'error': f'Invalid scene: {scene_name}. Valid scenes are: {", ".join(valid_scenes)}'
            }
        self.current_scene = scene_name
        return {
            'action': 'change_scene',
            'requestId': data.get('requestId'),
            'success': True,
            'scene': scene_name
        }

    async def _stream_frames(self, fps: float):
        """
        Send frames to the client at roughly ``fps`` frames per second.

        Runs until is_streaming is cleared, the task is cancelled, or an error
        occurs. is_streaming is always False once this coroutine exits.
        """
        try:
            frame_interval = 1.0 / fps  # Time between frames in seconds
            while self.is_streaming:
                # Monotonic clock: immune to wall-clock adjustments.
                started = time.monotonic()
                frame_bytes = self.game_manager.frame_generator.get_next_frame(
                    self.current_scene, [self.keyboard_state], [self.mouse_state]
                )
                frame_base64 = base64.b64encode(frame_bytes).decode('utf-8')
                await self.ws.send_json({
                    'action': 'frame',
                    'frameData': frame_base64,
                    'timestamp': time.time()
                })
                elapsed = time.monotonic() - started
                await asyncio.sleep(max(0.0, frame_interval - elapsed))
        except asyncio.CancelledError:
            logger.info(f"Frame streaming cancelled for user {self.user_id}")
            raise
        except Exception as e:
            logger.error(f"Error in frame streaming for user {self.user_id}: {str(e)}")
            if self.ws.closed:
                logger.info(f"WebSocket closed for user {self.user_id}")
            else:
                try:
                    await self.ws.send_json({
                        'action': 'frame_error',
                        'error': f'Streaming error: {str(e)}'
                    })
                except Exception as send_error:
                    logger.error(f"Error sending frame_error: {send_error}")
        finally:
            # Covers every exit path, including the closed-socket early exit
            # that previously left the flag stuck at True.
            self.is_streaming = False
class GameManager:
    """
    Manages all active gaming sessions and shared resources.

    One FrameGenerator instance is shared by every session. Session creation
    and deletion are serialised by ``session_lock``.
    """
    def __init__(self):
        # user_id -> GameSession
        self.sessions = {}
        self.session_lock = asyncio.Lock()
        # Initialize frame generator (loads all scene images up front)
        self.frame_generator = FrameGenerator()
        # Scene names clients may request, taken from the generator
        self.valid_scenes = list(self.frame_generator.scenes.keys())

    async def create_session(self, user_id: str, ws: web.WebSocketResponse) -> GameSession:
        """Create, start and register a new game session for ``user_id``."""
        async with self.session_lock:
            session = GameSession(user_id, ws, self)
            await session.start()
            self.sessions[user_id] = session
            return session

    async def delete_session(self, user_id: str) -> None:
        """
        Unregister and stop the session for ``user_id``; no-op if unknown.

        The session is removed from the registry before it is stopped, so an
        exception from stop() is logged instead of leaving a dead session behind.
        """
        async with self.session_lock:
            session = self.sessions.pop(user_id, None)
            if session is None:
                return
            try:
                await session.stop()
            except Exception as e:
                logger.error(f"Error stopping game session for user {user_id}: {str(e)}")
            logger.info(f"Deleted game session for user {user_id}")

    def get_session(self, user_id: str) -> Optional[GameSession]:
        """Get a game session if it exists."""
        return self.sessions.get(user_id)

    async def close_all_sessions(self) -> None:
        """
        Stop every active session (used during shutdown).

        A failure while stopping one session is logged and does not prevent
        the remaining sessions from being stopped.
        """
        async with self.session_lock:
            sessions = list(self.sessions.items())
            self.sessions.clear()
            for user_id, session in sessions:
                try:
                    await session.stop()
                except Exception as e:
                    logger.error(f"Error stopping game session for user {user_id}: {str(e)}")
            logger.info("Closed all active game sessions")

    @property
    def session_count(self) -> int:
        """Get the number of active sessions."""
        return len(self.sessions)

    def get_session_stats(self) -> Dict:
        """Return the session total, per-scene session counts and the number of streaming sessions."""
        stats = {
            'total_sessions': len(self.sessions),
            'active_scenes': {},
            'streaming_sessions': 0
        }
        for session in self.sessions.values():
            scene = session.current_scene
            stats['active_scenes'][scene] = stats['active_scenes'].get(scene, 0) + 1
            if session.is_streaming:
                stats['streaming_sessions'] += 1
        return stats
# Create global game manager
# Single module-level instance used by status_handler, websocket_handler and
# the shutdown hook in init_app. It is built at import time, so
# GameManager() -> FrameGenerator() loads every scene's images here.
game_manager = GameManager()
async def status_handler(request: web.Request) -> web.Response:
    """
    Serve GET /api/status.

    Responds with a JSON document that identifies the server and reports the
    current session statistics and the list of scenes clients may select.
    """
    payload = {
        'product': 'MatrixGame WebSocket Server',
        'version': '1.0.0',
        'active_sessions': game_manager.get_session_stats(),
        'available_scenes': game_manager.valid_scenes,
    }
    return web.json_response(payload)
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """
    Serve one WebSocket client for the lifetime of its connection.

    A new GameSession is created per connection. ``ping`` messages are
    answered inline; every other JSON object is queued on the session's
    action queue. The session is always deleted when the handler exits,
    including when the welcome message cannot be delivered.
    """
    ws = web.WebSocketResponse(
        max_msg_size=1024*1024*10,  # 10MB max message size
        timeout=60.0
    )
    await ws.prepare(request)

    # Generate a unique user ID for this connection
    user_id = str(uuid.uuid4())

    # Prefer the socket peer address; fall back to the first X-Forwarded-For
    # entry only when no peer address is available. NOTE: that header is
    # client-controlled, so the value is informational only.
    transport = request.transport
    peername = transport.get_extra_info('peername') if transport is not None else None
    if peername is not None:
        client_ip = peername[0]
    else:
        client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()
    logger.info(f"Client {user_id} connecting from IP: {client_ip}")

    # Store the user ID in the websocket for easy access
    ws.user_id = user_id

    # Everything after session creation sits inside try/finally so a client
    # that disconnects before the welcome message cannot leak its session.
    try:
        user_session = await game_manager.create_session(user_id, ws)
        await ws.send_json({
            'action': 'welcome',
            'userId': user_id,
            'message': 'Welcome to the MatrixGame WebSocket server!',
            'scenes': game_manager.valid_scenes
        })

        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                data = None  # reset so error replies never use a previous message
                try:
                    data = json.loads(msg.data)
                    # Valid JSON that is not an object (e.g. a list) has no .get().
                    if not isinstance(data, dict):
                        await ws.send_json({
                            'error': 'Invalid message: expected a JSON object',
                            'success': False
                        })
                        continue
                    if data.get('action') == 'ping':
                        # Respond to ping immediately
                        await ws.send_json({
                            'action': 'pong',
                            'requestId': data.get('requestId'),
                            'timestamp': time.time()
                        })
                    else:
                        # Route game actions to the session's action queue
                        await user_session.action_queue.put(data)
                except json.JSONDecodeError:
                    # Truncate: messages may be up to 10MB of untrusted data.
                    logger.error(f"Invalid JSON from user {user_id}: {msg.data[:200]}")
                    await ws.send_json({
                        'error': 'Invalid JSON message',
                        'success': False
                    })
                except Exception as e:
                    logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
                    await ws.send_json({
                        'action': data.get('action') if isinstance(data, dict) else 'unknown',
                        'success': False,
                        'error': f'Error processing message: {str(e)}'
                    })
            elif msg.type == WSMsgType.ERROR:
                logger.error(f"WebSocket error for user {user_id}: {ws.exception()}")
                break
            elif msg.type == WSMsgType.CLOSE:
                break
    finally:
        # Cleanup session (no-op if creation never completed)
        await game_manager.delete_session(user_id)
        logger.info(f"Connection closed for user {user_id}")
    return ws
async def init_app() -> web.Application:
    """
    Build the aiohttp application.

    Registers the WebSocket and status routes and a shutdown hook that stops
    all game sessions. The demo client is served from ./client next to this
    file, but only when that directory exists.
    """
    app = web.Application(
        client_max_size=1024**2*10  # 10MB max size
    )

    async def cleanup(app):
        logger.info("Shutting down server, closing all sessions...")
        await game_manager.close_all_sessions()

    app.on_shutdown.append(cleanup)

    app.router.add_get('/ws', websocket_handler)
    app.router.add_get('/api/status', status_handler)

    # add_static() raises ValueError for a missing directory, which would abort
    # startup; serve the demo client only if it was shipped alongside the server.
    client_dir = pathlib.Path(__file__).parent / 'client'
    if client_dir.is_dir():
        app.router.add_static('/client', path=client_dir)
    else:
        logger.warning(f"Client directory {client_dir} not found; /client will not be served")
    return app
def parse_args() -> argparse.Namespace:
    """Read --host/--port from the command line (defaults: 0.0.0.0:8080)."""
    cli = argparse.ArgumentParser(description="MatrixGame WebSocket Server")
    option_specs = (
        ("--host", str, "0.0.0.0", "Host IP to bind to"),
        ("--port", int, 8080, "Port to listen on"),
    )
    for flag, converter, fallback, help_text in option_specs:
        cli.add_argument(flag, type=converter, default=fallback, help=help_text)
    return cli.parse_args()
if __name__ == '__main__':
    args = parse_args()
    # run_app accepts a coroutine that returns the Application and awaits it
    # on the loop it runs the server on. Building the app first under a
    # separate, already-closed asyncio.run() loop is unnecessary and may fail
    # on older aiohttp releases that look up the current event loop.
    web.run_app(init_app(), host=args.host, port=args.port)