#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
MatrixGame Websocket Gaming Server

This script implements a websocket server for the MatrixGame project,
allowing real-time streaming of game frames based on player inputs.
"""

import asyncio
import json
import logging
import os
import pathlib
import time
import uuid
import base64
import argparse
from typing import Dict, List, Any, Optional

from aiohttp import web, WSMsgType

# Import the game engine
from engine import MatrixGameEngine
from utils import logger, parse_model_args, setup_gpu_environment

# Maximum size of an incoming WebSocket message / HTTP body (10MB).
MAX_MESSAGE_SIZE = 1024 * 1024 * 10

# RFC 6455 limits control-frame payloads to 125 bytes; two of them carry the
# close code, which leaves 123 bytes for the UTF-8 close reason.
WS_CLOSE_REASON_MAX_BYTES = 123

# CORS headers attached to every HTTP response produced by the app.
CORS_HEADERS = {
    'Access-Control-Allow-Origin': '*',
    'Access-Control-Allow-Methods': 'GET, POST, OPTIONS',
    'Access-Control-Allow-Headers': 'Content-Type, X-Requested-With',
}

# Directory holding the browser client (index.html and static assets).
CLIENT_DIR = pathlib.Path(__file__).parent / 'client'


def _close_reason(text: str) -> bytes:
    """Encode *text* as a WebSocket close reason, truncated to the legal size.

    The text is cut to at most ``WS_CLOSE_REASON_MAX_BYTES`` bytes of UTF-8; a
    multi-byte character split by the cut is dropped rather than sent broken.
    """
    encoded = text.encode('utf-8')
    if len(encoded) <= WS_CLOSE_REASON_MAX_BYTES:
        return encoded
    truncated = encoded[:WS_CLOSE_REASON_MAX_BYTES]
    return truncated.decode('utf-8', errors='ignore').encode('utf-8')


def _client_ip(request: web.Request) -> str:
    """Return the peer address of *request*, falling back to X-Forwarded-For.

    Returns ``'unknown'`` when neither the transport peer name nor the header
    is available.
    """
    transport = request.transport
    peername = transport.get_extra_info('peername') if transport is not None else None
    if peername is not None:
        return peername[0]
    return request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()


class GameSession:
    """
    Represents a user's gaming session.
    Each WebSocket connection gets its own session with separate queues.
    """

    # Maps client key names (lower-cased) to an index into ``keyboard_state``.
    KEY_MAP = {
        'w': 0, 'forward': 0,
        's': 1, 'back': 1, 'backward': 1,
        'a': 2, 'left': 2,
        'd': 3, 'right': 3,
        'space': 4, 'jump': 4,
        'shift': 5, 'attack': 5, 'ctrl': 5,
    }

    def __init__(self, user_id: str, ws: web.WebSocketResponse, game_manager: 'GameManager'):
        self.user_id = user_id
        self.ws = ws
        self.game_manager = game_manager

        # Create action queue for this user session
        self.action_queue = asyncio.Queue()

        # Session creation time
        self.created_at = time.time()
        self.last_activity = time.time()

        # Game state
        self.current_scene = "forest"  # Default scene
        self.is_streaming = False
        self.stream_task = None

        # Current input state
        self.keyboard_state = [0, 0, 0, 0, 0, 0]  # forward, back, left, right, jump, attack
        self.mouse_state = [0, 0]  # x, y

        self.background_tasks = []

    async def start(self):
        """Start all the queue processors for this session"""
        self.background_tasks = [
            asyncio.create_task(self._process_action_queue()),
        ]
        logger.info(f"Started game session for user {self.user_id}")

    async def stop(self):
        """Stop all background tasks for this session"""
        # Stop streaming if a stream task exists (running or already finished)
        await self._cancel_stream_task()

        # Cancel other background tasks
        for task in self.background_tasks:
            task.cancel()

        try:
            # Wait for tasks to complete cancellation
            await asyncio.gather(*self.background_tasks, return_exceptions=True)
        except asyncio.CancelledError:
            pass

        logger.info(f"Stopped game session for user {self.user_id}")

    async def _cancel_stream_task(self) -> None:
        """Cancel the frame-streaming task, if any, and wait for it to finish.

        Clears ``is_streaming`` and ``stream_task``. An exception the task
        ended with is logged here instead of being re-raised to the caller.
        """
        task, self.stream_task = self.stream_task, None
        self.is_streaming = False
        if task is None:
            return

        task.cancel()  # no-op when the task has already finished
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Frame streaming task for user {self.user_id} ended with error: {str(e)}")

    async def _process_action_queue(self):
        """Process game actions from the queue.

        Runs until cancelled: each queued message is dispatched on its
        ``'action'`` field and the handler's result is sent back to the client.
        """
        handlers = {
            'start_stream': self._handle_start_stream,
            'stop_stream': self._handle_stop_stream,
            'keyboard_input': self._handle_keyboard_input,
            'mouse_input': self._handle_mouse_input,
            'change_scene': self._handle_scene_change,
        }

        while True:
            data = await self.action_queue.get()
            try:
                action_type = data.get('action')
                # Only strings can name an action; anything else (including
                # unhashable JSON values) is reported as unknown.
                handler = handlers.get(action_type) if isinstance(action_type, str) else None

                if handler is not None:
                    result = await handler(data)
                else:
                    result = {
                        'action': action_type,
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Unknown action: {action_type}'
                    }

                # Send response back to the client
                await self.ws.send_json(result)

                # Update last activity time
                self.last_activity = time.time()
            except Exception as e:
                logger.error(f"Error processing action for user {self.user_id}: {str(e)}")
                try:
                    await self.ws.send_json({
                        'action': data.get('action'),
                        'requestId': data.get('requestId', 'unknown'),
                        'success': False,
                        'error': f'Error processing action: {str(e)}'
                    })
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")
            finally:
                self.action_queue.task_done()

    async def _handle_start_stream(self, data: Dict) -> Dict:
        """Handle request to start streaming frames.

        Reads the optional ``'fps'`` field (default 16). Returns an error
        response when a stream is already active or ``fps`` is not a positive
        number; otherwise starts the streaming task.
        """
        if self.is_streaming:
            return {
                'action': 'start_stream',
                'requestId': data.get('requestId'),
                'success': False,
                'error': 'Stream already active'
            }

        fps = data.get('fps', 16)

        # fps comes straight from the client: reject values that would make
        # the frame interval (1 / fps) undefined or negative.
        is_number = isinstance(fps, (int, float)) and not isinstance(fps, bool)
        if not is_number or not fps > 0:
            return {
                'action': 'start_stream',
                'requestId': data.get('requestId'),
                'success': False,
                'error': f'Invalid fps: {fps!r}. Expected a positive number'
            }

        self.is_streaming = True
        self.stream_task = asyncio.create_task(self._stream_frames(fps))

        return {
            'action': 'start_stream',
            'requestId': data.get('requestId'),
            'success': True,
            'message': f'Streaming started at {fps} FPS'
        }

    async def _handle_stop_stream(self, data: Dict) -> Dict:
        """Handle request to stop streaming frames"""
        if not self.is_streaming:
            return {
                'action': 'stop_stream',
                'requestId': data.get('requestId'),
                'success': False,
                'error': 'No active stream to stop'
            }

        await self._cancel_stream_task()

        return {
            'action': 'stop_stream',
            'requestId': data.get('requestId'),
            'success': True,
            'message': 'Streaming stopped'
        }

    async def _handle_keyboard_input(self, data: Dict) -> Dict:
        """Handle keyboard input from client.

        Sets the ``keyboard_state`` slot mapped to ``data['key']`` to 1 or 0
        depending on ``data['pressed']``. Keys that are not in ``KEY_MAP``
        (or are not strings) leave the state unchanged.
        """
        key = data.get('key', '')
        pressed = data.get('pressed', False)

        key_idx = self.KEY_MAP.get(key.lower()) if isinstance(key, str) else None
        if key_idx is not None:
            self.keyboard_state[key_idx] = 1 if pressed else 0

        return {
            'action': 'keyboard_input',
            'requestId': data.get('requestId'),
            'success': True,
            'keyboardState': self.keyboard_state
        }

    async def _handle_mouse_input(self, data: Dict) -> Dict:
        """Handle mouse movement/input from client.

        Stores ``data['x']`` and ``data['y']`` (default 0) as floats. The
        values are stored as received; no clamping or scaling is applied.
        """
        try:
            mouse_x = float(data.get('x', 0))
            mouse_y = float(data.get('y', 0))
        except (TypeError, ValueError):
            return {
                'action': 'mouse_input',
                'requestId': data.get('requestId'),
                'success': False,
                'error': 'Mouse coordinates must be numeric'
            }

        self.mouse_state = [mouse_x, mouse_y]

        return {
            'action': 'mouse_input',
            'requestId': data.get('requestId'),
            'success': True,
            'mouseState': self.mouse_state
        }

    async def _handle_scene_change(self, data: Dict) -> Dict:
        """Handle scene change requests"""
        scene_name = data.get('scene', 'forest')
        valid_scenes = self.game_manager.valid_scenes

        if scene_name not in valid_scenes:
            return {
                'action': 'change_scene',
                'requestId': data.get('requestId'),
                'success': False,
                'error': f'Invalid scene: {scene_name}. Valid scenes are: {", ".join(valid_scenes)}'
            }

        self.current_scene = scene_name

        return {
            'action': 'change_scene',
            'requestId': data.get('requestId'),
            'success': True,
            'scene': scene_name
        }

    async def _stream_frames(self, fps: float):
        """Stream frames to the client at the specified FPS.

        Loops while ``is_streaming`` is set, generating one frame per
        iteration and sending it base64-encoded in a ``'frame'`` message.
        On an unexpected error the client is notified with ``'frame_error'``
        when the socket is still open. ``is_streaming`` is always cleared
        when the loop ends.
        """
        try:
            frame_interval = 1.0 / fps  # Time between frames in seconds

            while self.is_streaming:
                start_time = time.time()

                # Generate frame based on current keyboard and mouse state
                keyboard_condition = [self.keyboard_state]
                mouse_condition = [self.mouse_state]

                # Use the engine to generate the next frame.
                # NOTE(review): this is a synchronous call on the event loop;
                # if frame generation is slow it stalls every session. Moving
                # it to an executor needs confirmation that the engine is
                # thread-safe.
                frame_bytes = self.game_manager.engine.generate_frame(
                    self.current_scene,
                    keyboard_condition,
                    mouse_condition
                )

                # Encode as base64 for sending in JSON
                frame_base64 = base64.b64encode(frame_bytes).decode('utf-8')

                # Send frame to client
                await self.ws.send_json({
                    'action': 'frame',
                    'frameData': frame_base64,
                    'timestamp': time.time()
                })

                # Calculate sleep time to maintain FPS
                elapsed = time.time() - start_time
                sleep_time = max(0, frame_interval - elapsed)
                await asyncio.sleep(sleep_time)
        except asyncio.CancelledError:
            logger.info(f"Frame streaming cancelled for user {self.user_id}")
        except Exception as e:
            logger.error(f"Error in frame streaming for user {self.user_id}: {str(e)}")
            if self.ws.closed:
                logger.info(f"WebSocket closed for user {self.user_id}")
                return

            # Notify client of error (best effort)
            try:
                await self.ws.send_json({
                    'action': 'frame_error',
                    'error': f'Streaming error: {str(e)}'
                })
            except Exception as send_error:
                logger.error(f"Error sending error response: {send_error}")
        finally:
            # Stop streaming, whichever way the loop ended, so that a later
            # start_stream request is not rejected as "already active".
            self.is_streaming = False


class GameManager:
    """
    Manages all active gaming sessions and shared resources.
    """

    def __init__(self, args: argparse.Namespace):
        self.sessions = {}
        self.session_lock = asyncio.Lock()

        # Initialize game engine
        self.engine = MatrixGameEngine(args)

        # Load valid scenes from engine
        self.valid_scenes = self.engine.get_valid_scenes()

    async def create_session(self, user_id: str, ws: web.WebSocketResponse) -> GameSession:
        """Create a new game session"""
        async with self.session_lock:
            # Create a new session for this user
            session = GameSession(user_id, ws, self)
            await session.start()
            self.sessions[user_id] = session
            return session

    async def delete_session(self, user_id: str) -> None:
        """Delete a game session and clean up resources.

        Does nothing when no session is registered for *user_id*.
        """
        async with self.session_lock:
            # Unregister first so a failing stop() cannot leave a dead
            # session behind in the registry.
            session = self.sessions.pop(user_id, None)
            if session is None:
                return
            await session.stop()
            logger.info(f"Deleted game session for user {user_id}")

    def get_session(self, user_id: str) -> Optional[GameSession]:
        """Get a game session if it exists"""
        return self.sessions.get(user_id)

    async def close_all_sessions(self) -> None:
        """Close all active sessions (used during shutdown)"""
        async with self.session_lock:
            for user_id, session in list(self.sessions.items()):
                try:
                    await session.stop()
                except Exception as e:
                    # Keep going: one failing session must not block shutdown
                    # of the remaining ones.
                    logger.error(f"Error stopping session for user {user_id}: {str(e)}")
            self.sessions.clear()
            logger.info("Closed all active game sessions")

    @property
    def session_count(self) -> int:
        """Get the number of active sessions"""
        return len(self.sessions)

    def get_session_stats(self) -> Dict:
        """Get statistics about active sessions.

        Returns a dict with the total session count, the number of sessions
        per scene and the number of sessions currently streaming.
        """
        stats = {
            'total_sessions': len(self.sessions),
            'active_scenes': {},
            'streaming_sessions': 0
        }

        # Count sessions by scene and streaming status
        for session in self.sessions.values():
            scene = session.current_scene
            stats['active_scenes'][scene] = stats['active_scenes'].get(scene, 0) + 1
            if session.is_streaming:
                stats['streaming_sessions'] += 1

        return stats


# Create global game manager (assigned in init_app)
game_manager = None


async def status_handler(request: web.Request) -> web.Response:
    """Handler for API status endpoint"""
    # Get session statistics
    session_stats = game_manager.get_session_stats()

    return web.json_response({
        'product': 'MatrixGame WebSocket Server',
        'version': '1.0.0',
        'active_sessions': session_stats,
        'available_scenes': game_manager.valid_scenes
    })


async def root_handler(request: web.Request) -> web.Response:
    """Handler for serving the client at the root path.

    Returns the contents of ``client/index.html``, or a 404 response when
    the file does not exist.
    """
    client_path = CLIENT_DIR / 'index.html'

    try:
        html_content = client_path.read_text(encoding='utf-8')
    except FileNotFoundError:
        logger.error(f"Client file not found: {client_path}")
        return web.Response(status=404, text="Client not found")

    return web.Response(text=html_content, content_type='text/html')


async def _handle_text_message(
    ws: web.WebSocketResponse,
    user_session: GameSession,
    user_id: str,
    raw: str,
) -> None:
    """Parse one text frame and answer it or queue it for the session.

    ``ping`` actions are answered immediately; every other action is put on
    the session's action queue. Malformed JSON and payloads that are not
    JSON objects are answered with an error message.
    """
    action = 'unknown'
    try:
        data = json.loads(raw)

        # Valid JSON is not necessarily an object; the protocol needs one.
        if not isinstance(data, dict):
            logger.error(f"Non-object JSON from user {user_id}")
            if not ws.closed:
                await ws.send_json({
                    'error': 'JSON message must be an object',
                    'success': False
                })
            return

        action = data.get('action')
        logger.debug(f"Received {action} message from user {user_id}")

        if action == 'ping':
            # Respond to ping immediately
            await ws.send_json({
                'action': 'pong',
                'requestId': data.get('requestId'),
                'timestamp': time.time()
            })
        else:
            # Route game actions to the session's action queue
            await user_session.action_queue.put(data)
    except json.JSONDecodeError:
        logger.error(f"Invalid JSON from user {user_id}: {raw}")
        if not ws.closed:
            await ws.send_json({
                'error': 'Invalid JSON message',
                'success': False
            })
    except Exception as e:
        logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
        if not ws.closed:
            await ws.send_json({
                'action': action,
                'success': False,
                'error': f'Error processing message: {str(e)}'
            })


async def websocket_handler(request: web.Request) -> web.StreamResponse:
    """Handle WebSocket connections with robust error handling.

    Upgrades the request, creates a game session for the connection, sends a
    welcome message and then processes incoming messages until the socket
    closes. The session is always deleted before returning. Returns a plain
    400 response when the request cannot be upgraded.
    """
    logger.info(f"WebSocket connection attempt - PATH: {request.path}, QUERY: {request.query_string}")

    # Log request headers at debug level only (could contain sensitive information)
    logger.debug(f"WebSocket request headers: {dict(request.headers)}")

    # Prepare a WebSocket response with appropriate settings
    ws = web.WebSocketResponse(
        max_msg_size=MAX_MESSAGE_SIZE,
        timeout=60.0,
        heartbeat=30.0  # Add heartbeat to keep connection alive
    )

    # Check if WebSocket protocol is supported
    if not ws.can_prepare(request):
        logger.error("Cannot prepare WebSocket: WebSocket protocol not supported")
        return web.Response(status=400, text="WebSocket protocol not supported")

    try:
        logger.info("Preparing WebSocket connection...")
        await ws.prepare(request)

        # Generate a unique user ID for this connection
        user_id = str(uuid.uuid4())

        client_ip = _client_ip(request)
        logger.info(f"Client {user_id} connecting from IP: {client_ip} - WebSocket connection established")
    except Exception as e:
        logger.error(f"Error establishing WebSocket connection: {str(e)}", exc_info=True)
        if not ws.closed and ws.prepared:
            await ws.close(code=1011, message=_close_reason(f"Server error: {str(e)}"))
        return ws

    try:
        # Store the user ID in the websocket for easy access
        ws.user_id = user_id

        # Create a new session for this user
        logger.info(f"Creating game session for user {user_id}")
        user_session = await game_manager.create_session(user_id, ws)
        logger.info(f"Game session created for user {user_id}")
    except Exception as session_error:
        logger.error(f"Error creating game session: {str(session_error)}", exc_info=True)
        if not ws.closed:
            await ws.close(code=1011, message=_close_reason(f"Server error: {str(session_error)}"))
        # Safe even when nothing was registered: delete_session ignores
        # unknown user IDs.
        await game_manager.delete_session(user_id)
        return ws

    # Send initial welcome message
    try:
        await ws.send_json({
            'action': 'welcome',
            'userId': user_id,
            'message': 'Welcome to the MatrixGame WebSocket server!',
            'scenes': game_manager.valid_scenes
        })
        logger.info(f"Sent welcome message to user {user_id}")
    except Exception as welcome_error:
        logger.error(f"Error sending welcome message: {str(welcome_error)}")
        if not ws.closed:
            await ws.close(code=1011, message=b"Failed to send welcome message")
        await game_manager.delete_session(user_id)
        return ws

    try:
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                await _handle_text_message(ws, user_session, user_id, msg.data)
            elif msg.type == WSMsgType.ERROR:
                logger.error(f"WebSocket error for user {user_id}: {ws.exception()}")
                break
            elif msg.type == WSMsgType.CLOSE:
                logger.info(f"WebSocket close received for user {user_id} (code: {msg.data}, message: {msg.extra})")
                break
            elif msg.type == WSMsgType.CLOSING:
                logger.info(f"WebSocket closing for user {user_id}")
                break
            elif msg.type == WSMsgType.CLOSED:
                logger.info(f"WebSocket already closed for user {user_id}")
                break
    except Exception as ws_error:
        logger.error(f"Unexpected WebSocket error for user {user_id}: {str(ws_error)}", exc_info=True)
    finally:
        # Cleanup session
        try:
            logger.info(f"Cleaning up session for user {user_id}")
            await game_manager.delete_session(user_id)
            logger.info(f"Connection closed for user {user_id}")
        except Exception as cleanup_error:
            logger.error(f"Error during session cleanup for user {user_id}: {str(cleanup_error)}")

    return ws


async def init_app(args, base_path="") -> web.Application:
    """Initialize the web application.

    Creates the global ``GameManager`` from *args* and registers the
    WebSocket, API, client and static-asset routes under *base_path*. When
    *base_path* is non-empty the same routes are also registered at the root.
    """
    global game_manager

    # Normalize the prefix to '' or '/segment[/segment...]' so that values
    # such as 'game', '/game/' or '/' all yield valid route paths.
    stripped_path = base_path.strip('/')
    base_path = f'/{stripped_path}' if stripped_path else ''

    # Initialize game manager with command line args
    game_manager = GameManager(args)

    app = web.Application(
        client_max_size=MAX_MESSAGE_SIZE  # 10MB max size
    )

    # Add cleanup logic
    async def cleanup(app):
        logger.info("Shutting down server, closing all sessions...")
        await game_manager.close_all_sessions()

    app.on_shutdown.append(cleanup)

    # Configure CORS for all routes
    @web.middleware
    async def cors_middleware(request, handler):
        if request.method == 'OPTIONS':
            # Handle preflight requests
            return web.Response(headers=CORS_HEADERS)

        # Normal request, call the handler
        try:
            resp = await handler(request)
        except web.HTTPException as http_error:
            # Error responses raised as exceptions need the headers as well
            http_error.headers.update(CORS_HEADERS)
            raise

        # Add CORS headers to the response
        resp.headers.update(CORS_HEADERS)
        return resp

    app.middlewares.append(cors_middleware)

    # Add a debug endpoint to help diagnose WebSocket issues
    async def debug_handler(request):
        debug_info = {
            "client_ip": request.remote,
            "server_host": request.host,
            "headers": dict(request.headers),
            "request_path": request.path,
            "server_time": time.time(),
            "base_path": base_path,
            "websocket_route": f"{base_path}/ws",
            "all_routes": [route.name for route in app.router.routes() if route.name],
            "server_info": {
                "active_sessions": game_manager.session_count,
                "available_scenes": game_manager.valid_scenes
            }
        }
        return web.json_response(debug_info)

    # Set up routes with the base_path
    logger.info(f"Setting up WebSocket route at {base_path}/ws")
    app.router.add_get(f'{base_path}/ws', websocket_handler, name='ws_handler')

    # Add routes for API and debug endpoints
    app.router.add_get(f'{base_path}/api/status', status_handler, name='status_handler')
    app.router.add_get(f'{base_path}/api/debug', debug_handler, name='debug_handler')

    # Serve the client and its static assets
    app.router.add_get(f'{base_path}/', root_handler, name='root_handler')
    app.router.add_static(f'{base_path}/assets', CLIENT_DIR, name='static_handler')

    # Mirror the WebSocket, client and asset routes at the root path for
    # Hugging Face Spaces compatibility
    if base_path:
        logger.info(f"Adding additional WebSocket route at /ws")
        app.router.add_get('/ws', websocket_handler, name='ws_root_handler')
        app.router.add_get('/', root_handler, name='root_handler_no_base')
        app.router.add_static('/assets', CLIENT_DIR, name='static_handler_no_base')

    return app


def parse_args() -> argparse.Namespace:
    """Parse server-specific command line arguments.

    Combines the server options (``--host``, ``--port``, ``--path``) with the
    namespace returned by ``parse_model_args()``. If both define the same
    option name, the server value wins.
    """
    parser = argparse.ArgumentParser(description="MatrixGame WebSocket Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host IP to bind to")
    parser.add_argument("--port", type=int, default=8080, help="Port to listen on")
    parser.add_argument("--path", type=str, default="", help="Base path for the server (for proxy setups)")

    # Parse server args first; unrecognized options are left for the model parser
    server_args, _unparsed = parser.parse_known_args()

    # Parse model args and combine
    model_args = parse_model_args()

    # Merge through a dict: Namespace(**a, **b) raises TypeError when both
    # namespaces share an option name.
    combined = {**vars(model_args), **vars(server_args)}
    return argparse.Namespace(**combined)


def main() -> None:
    """Configure the environment, parse arguments and run the server."""
    # Configure GPU environment
    setup_gpu_environment()

    # Parse command line arguments
    args = parse_args()

    # Start server. Passing the coroutine lets run_app build the application
    # on the same event loop that serves it, so loop-bound objects created
    # during init (such as the session lock) belong to the right loop.
    logger.info(f"Starting MatrixGame WebSocket Server at {args.host}:{args.port}")
    web.run_app(init_app(args, base_path=args.path), host=args.host, port=args.port)


if __name__ == '__main__':
    main()