|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
MatrixGame WebSocket Gaming Server
|
|
|
|
|
This script implements a websocket server for the MatrixGame project, |
|
|
allowing real-time streaming of game frames based on player inputs. |
|
|
""" |
|
|
|
|
|
import argparse
import asyncio
import base64
import io
import json
import logging
import math
import os
import pathlib
import time
import uuid
from typing import Dict, List, Any, Optional

import torch
import numpy as np
from PIL import Image
import cv2
from aiohttp import web, WSMsgType

from condtions import Bench_actions_76
|
|
|
|
|
|
|
|
# Root logging configuration: INFO and above, each record prefixed with a
# timestamp, the logger name and the level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used throughout this server.
logger = logging.getLogger(__name__)
|
|
|
|
|
class FrameGenerator:
    """
    Simplified frame generator for the game.

    Serves pre-loaded scene images (or a labelled grey placeholder when a
    scene has no loadable images) and draws the current keyboard/mouse state
    on top of them. In production, this would use the MatrixGame model.

    Frames are stored in OpenCV's BGR channel order, which is the order the
    ``cv2.putText``/``cv2.circle`` colours and ``cv2.imencode`` assume.
    """

    # Scene names loaded at construction time, in this order.
    SCENE_NAMES = ('forest', 'desert', 'beach', 'hills', 'river', 'icy', 'mushroom', 'plain')
    # Overlay labels for the entries of keyboard_condition[0], by index.
    KEY_LABELS = ("W", "S", "A", "D", "JUMP", "ATTACK")
    # File extensions (compared case-insensitively) treated as scene images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # Grey level used for placeholder and fallback frames.
    PLACEHOLDER_GREY = 100

    def __init__(self):
        self.frame_width = 640
        self.frame_height = 360
        self.fps = 16  # nominal rate; not read by this class itself
        # Single counter shared by every get_next_frame call, whatever the
        # scene or caller.
        self.frame_count = 0
        self.scenes = {name: self._load_scene_frames(name) for name in self.SCENE_NAMES}

    def _blank_frame(self):
        """Return a uniform grey uint8 frame of the configured size."""
        return np.full(
            (self.frame_height, self.frame_width, 3),
            self.PLACEHOLDER_GREY,
            dtype=np.uint8,
        )

    def _load_scene_frames(self, scene_name):
        """Load the frames for ``scene_name`` from the asset directory.

        Every ``.png``/``.jpg``/``.jpeg`` file (case-insensitive) in
        ``./GameWorldScore/asset/init_image/<scene_name>`` is loaded in sorted
        filename order, resized to the frame size and converted to BGR.
        Files that fail to load are logged and skipped. If nothing could be
        loaded, a single grey placeholder frame labelled with the scene name
        is used instead.

        Returns:
            A non-empty list of ``(frame_height, frame_width, 3)`` uint8 arrays.
        """
        frames = []
        scene_dir = f"./GameWorldScore/asset/init_image/{scene_name}"

        # isdir (not exists): os.listdir would raise on a plain file.
        if os.path.isdir(scene_dir):
            image_files = sorted(
                f for f in os.listdir(scene_dir)
                if f.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            for img_file in image_files:
                img_path = os.path.join(scene_dir, img_file)
                try:
                    # Context manager closes the source file handle.
                    with Image.open(img_path) as img:
                        rgb = img.convert("RGB").resize((self.frame_width, self.frame_height))
                    # PIL yields RGB, but OpenCV drawing and imencode treat
                    # 3-channel arrays as BGR; convert once here so the
                    # encoded JPEGs have the right colours.
                    frames.append(cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2BGR))
                except Exception as e:
                    logger.error(f"Error loading image {img_file}: {str(e)}")

        if not frames:
            frame = self._blank_frame()
            cv2.putText(frame, f"Scene: {scene_name}", (50, 180),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            frames.append(frame)

        return frames

    def get_next_frame(self, scene_name, keyboard_condition=None, mouse_condition=None):
        """
        Generate the next frame based on current conditions.

        Args:
            scene_name: Name of the current scene; unknown names fall back
                to 'forest'.
            keyboard_condition: Sequence whose first element is the per-key
                pressed-state vector; drawn as labels (green when truthy).
                Entries beyond the known labels are ignored.
            mouse_condition: Sequence whose first element is an (x, y) pair;
                drawn as a dot offset from the frame centre by 100 px per unit.

        Returns:
            JPEG bytes of the frame.
        """
        scene_frames = self.scenes.get(scene_name, self.scenes['forest'])

        frame = scene_frames[self.frame_count % len(scene_frames)].copy()
        self.frame_count += 1

        if keyboard_condition:
            # zip() stops at the shorter sequence, so a state vector longer
            # than KEY_LABELS cannot index past the label table.
            for i, (label, key_pressed) in enumerate(zip(self.KEY_LABELS, keyboard_condition[0])):
                color = (0, 255, 0) if key_pressed else (100, 100, 100)
                cv2.putText(frame, label, (20 + i * 100, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        if mouse_condition:
            mouse_x, mouse_y = mouse_condition[0]
            offset_x = int(mouse_x * 100)
            offset_y = int(mouse_y * 100)
            center_x, center_y = self.frame_width // 2, self.frame_height // 2
            # Clamp to a band one frame-size beyond each edge. Positions that
            # were off-screen stay off-screen (radius 10 << frame size), but
            # huge inputs can no longer exceed OpenCV's int point range.
            dot_x = min(max(center_x + offset_x, -self.frame_width), 2 * self.frame_width)
            dot_y = min(max(center_y - offset_y, -self.frame_height), 2 * self.frame_height)
            cv2.circle(frame, (dot_x, dot_y), 10, (255, 0, 0), -1)
            cv2.putText(frame, f"Mouse: {mouse_x:.2f}, {mouse_y:.2f}",
                        (self.frame_width - 250, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)

        success, buffer = cv2.imencode('.jpg', frame)
        if not success:
            logger.error("Failed to encode frame as JPEG")
            # Fall back to a plain grey frame so the caller still gets bytes.
            success, buffer = cv2.imencode('.jpg', self._blank_frame())

        return buffer.tobytes()
|
|
|
|
|
class GameSession:
    """
    Represents a user's gaming session.

    Each WebSocket connection gets its own session with a separate action
    queue, its own keyboard/mouse state and at most one frame-streaming task.
    Actions put on ``action_queue`` are handled in order by the background
    task started in ``start()``, and each one gets a JSON reply on ``ws``.
    """

    # Highest streaming rate a client may request; larger requests are
    # clamped so a single client cannot monopolise the shared event loop and
    # frame generator.
    MAX_FPS = 60

    # Client key names (matched lower-cased) -> index into keyboard_state.
    KEY_MAP = {
        'w': 0, 'forward': 0,
        's': 1, 'back': 1, 'backward': 1,
        'a': 2, 'left': 2,
        'd': 3, 'right': 3,
        'space': 4, 'jump': 4,
        'shift': 5, 'attack': 5, 'ctrl': 5,
    }

    def __init__(self, user_id: str, ws: web.WebSocketResponse, game_manager):
        self.user_id = user_id
        self.ws = ws
        # Must expose ``frame_generator`` and ``valid_scenes``.
        self.game_manager = game_manager

        # Pending client actions, consumed by _process_action_queue.
        self.action_queue = asyncio.Queue()

        self.created_at = time.time()
        self.last_activity = time.time()

        self.current_scene = "forest"
        self.is_streaming = False
        self.stream_task = None

        # One flag per KEY_MAP index (W, S, A, D, JUMP, ATTACK).
        self.keyboard_state = [0, 0, 0, 0, 0, 0]
        self.mouse_state = [0, 0]

        self.background_tasks = []

    async def start(self):
        """Start the background action processor for this session."""
        self.background_tasks = [
            asyncio.create_task(self._process_action_queue()),
        ]
        logger.info(f"Started game session for user {self.user_id}")

    async def stop(self):
        """Cancel the action processor and any active frame stream."""
        for task in self.background_tasks:
            task.cancel()
        try:
            await asyncio.gather(*self.background_tasks, return_exceptions=True)
        except asyncio.CancelledError:
            pass

        stream_task, self.stream_task = self.stream_task, None
        self.is_streaming = False
        if stream_task is not None and not stream_task.done():
            stream_task.cancel()
            try:
                await stream_task
            except asyncio.CancelledError:
                pass

        logger.info(f"Stopped game session for user {self.user_id}")

    async def _process_action_queue(self):
        """Handle queued client actions one at a time, replying to each.

        Runs until cancelled. An exception while handling an action is
        logged and reported to the client; it does not end the loop.
        """
        handlers = {
            'start_stream': self._handle_start_stream,
            'stop_stream': self._handle_stop_stream,
            'keyboard_input': self._handle_keyboard_input,
            'mouse_input': self._handle_mouse_input,
            'change_scene': self._handle_scene_change,
        }
        while True:
            data = await self.action_queue.get()
            try:
                action_type = data.get('action')
                # Non-string actions (e.g. lists) are unhashable; treat them
                # as unknown rather than failing the dict lookup.
                handler = handlers.get(action_type) if isinstance(action_type, str) else None
                if handler is not None:
                    result = await handler(data)
                else:
                    result = {
                        'action': action_type,
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Unknown action: {action_type}'
                    }

                await self.ws.send_json(result)
                self.last_activity = time.time()

            except Exception as e:
                logger.error(f"Error processing action for user {self.user_id}: {str(e)}")
                try:
                    await self.ws.send_json({
                        'action': data.get('action'),
                        'requestId': data.get('requestId', 'unknown'),
                        'success': False,
                        'error': f'Error processing action: {str(e)}'
                    })
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")
            finally:
                self.action_queue.task_done()

    async def _handle_start_stream(self, data: Dict) -> Dict:
        """Start streaming frames at the client-requested rate.

        ``data['fps']`` (default 16) must be a positive, finite int or float;
        values above ``MAX_FPS`` are clamped. An error response is returned
        if a stream is already active or the rate is invalid.
        """
        if self.is_streaming:
            return {
                'action': 'start_stream',
                'requestId': data.get('requestId'),
                'success': False,
                'error': 'Stream already active'
            }

        fps = data.get('fps', 16)
        # fps comes straight from the client. 0, NaN, negatives or
        # non-numbers would break (or unthrottle) _stream_frames; bool is an
        # int subclass but not a meaningful rate.
        is_number = isinstance(fps, (int, float)) and not isinstance(fps, bool)
        if not is_number or (isinstance(fps, float) and not math.isfinite(fps)) or fps <= 0:
            return {
                'action': 'start_stream',
                'requestId': data.get('requestId'),
                'success': False,
                'error': f'Invalid fps: {fps!r}. Must be a positive number'
            }
        fps = min(fps, self.MAX_FPS)

        self.is_streaming = True
        self.stream_task = asyncio.create_task(self._stream_frames(fps))

        return {
            'action': 'start_stream',
            'requestId': data.get('requestId'),
            'success': True,
            'message': f'Streaming started at {fps} FPS'
        }

    async def _handle_stop_stream(self, data: Dict) -> Dict:
        """Stop the active frame stream and wait for its task to finish."""
        if not self.is_streaming:
            return {
                'action': 'stop_stream',
                'requestId': data.get('requestId'),
                'success': False,
                'error': 'No active stream to stop'
            }

        self.is_streaming = False
        if self.stream_task:
            self.stream_task.cancel()
            try:
                await self.stream_task
            except asyncio.CancelledError:
                pass
            self.stream_task = None

        return {
            'action': 'stop_stream',
            'requestId': data.get('requestId'),
            'success': True,
            'message': 'Streaming stopped'
        }

    async def _handle_keyboard_input(self, data: Dict) -> Dict:
        """Record a key press/release; unknown key names leave state unchanged."""
        key = data.get('key', '')
        pressed = data.get('pressed', False)

        key_idx = self.KEY_MAP.get(key.lower())
        if key_idx is not None:
            self.keyboard_state[key_idx] = 1 if pressed else 0

        return {
            'action': 'keyboard_input',
            'requestId': data.get('requestId'),
            'success': True,
            'keyboardState': self.keyboard_state
        }

    async def _handle_mouse_input(self, data: Dict) -> Dict:
        """Record the mouse position; both coordinates must be finite numbers.

        Values are coerced with ``float()``, so values ``float()`` cannot
        parse raise and are reported by the action processor.
        """
        mouse_x = float(data.get('x', 0))
        mouse_y = float(data.get('y', 0))

        # json.loads accepts NaN/Infinity and float() accepts 'nan'/'inf';
        # such values would later break frame rendering.
        if not (math.isfinite(mouse_x) and math.isfinite(mouse_y)):
            return {
                'action': 'mouse_input',
                'requestId': data.get('requestId'),
                'success': False,
                'error': 'Mouse coordinates must be finite numbers'
            }

        self.mouse_state = [mouse_x, mouse_y]

        return {
            'action': 'mouse_input',
            'requestId': data.get('requestId'),
            'success': True,
            'mouseState': self.mouse_state
        }

    async def _handle_scene_change(self, data: Dict) -> Dict:
        """Switch to another scene known to the game manager."""
        scene_name = data.get('scene', 'forest')
        # Single source of truth: the scenes the shared frame generator loaded.
        valid_scenes = self.game_manager.valid_scenes

        if scene_name not in valid_scenes:
            return {
                'action': 'change_scene',
                'requestId': data.get('requestId'),
                'success': False,
                'error': f'Invalid scene: {scene_name}. Valid scenes are: {", ".join(valid_scenes)}'
            }

        self.current_scene = scene_name

        return {
            'action': 'change_scene',
            'requestId': data.get('requestId'),
            'success': True,
            'scene': scene_name
        }

    async def _stream_frames(self, fps):
        """Send a base64 JPEG 'frame' message roughly every ``1 / fps`` seconds.

        Runs until ``is_streaming`` is cleared or the task is cancelled. On
        an unexpected error the client gets a 'frame_error' message (unless
        the socket is already closed). ``is_streaming`` is always False
        when this coroutine ends.
        """
        frame_interval = 1.0 / fps

        try:
            while self.is_streaming:
                # monotonic: wall-clock jumps must not distort the pacing.
                start_time = time.monotonic()

                frame_bytes = self.game_manager.frame_generator.get_next_frame(
                    self.current_scene, [self.keyboard_state], [self.mouse_state]
                )

                await self.ws.send_json({
                    'action': 'frame',
                    'frameData': base64.b64encode(frame_bytes).decode('utf-8'),
                    'timestamp': time.time()
                })

                elapsed = time.monotonic() - start_time
                await asyncio.sleep(max(0, frame_interval - elapsed))

        except asyncio.CancelledError:
            logger.info(f"Frame streaming cancelled for user {self.user_id}")
            # Re-raise so the cancellation is visible to whoever awaits us.
            raise
        except Exception as e:
            logger.error(f"Error in frame streaming for user {self.user_id}: {str(e)}")
            if self.ws.closed:
                logger.info(f"WebSocket closed for user {self.user_id}")
                return

            try:
                await self.ws.send_json({
                    'action': 'frame_error',
                    'error': f'Streaming error: {str(e)}'
                })
            except Exception as send_error:
                logger.warning(f"Could not send frame_error to user {self.user_id}: {send_error}")
        finally:
            self.is_streaming = False
|
|
|
|
|
class GameManager:
    """
    Manages all active gaming sessions and shared resources.

    Keeps the ``user_id -> GameSession`` registry (create/delete/close-all
    run under ``session_lock``) and one FrameGenerator shared by every session.
    """

    def __init__(self):
        # user_id -> GameSession
        self.sessions = {}
        self.session_lock = asyncio.Lock()

        # Single generator used by all sessions.
        self.frame_generator = FrameGenerator()

        # Scene names, in the order the generator loaded them.
        self.valid_scenes = [*self.frame_generator.scenes]

    async def create_session(self, user_id: str, ws: web.WebSocketResponse) -> GameSession:
        """Build, start and register a session for ``user_id``; return it."""
        async with self.session_lock:
            new_session = GameSession(user_id, ws, self)
            await new_session.start()
            self.sessions[user_id] = new_session
        return new_session

    async def delete_session(self, user_id: str) -> None:
        """Stop and unregister the session for ``user_id``, if there is one."""
        async with self.session_lock:
            if user_id not in self.sessions:
                return
            await self.sessions[user_id].stop()
            self.sessions.pop(user_id)
            logger.info(f"Deleted game session for user {user_id}")

    def get_session(self, user_id: str) -> Optional[GameSession]:
        """Return the session for ``user_id``, or None."""
        return self.sessions.get(user_id)

    async def close_all_sessions(self) -> None:
        """Stop every session and empty the registry (used at shutdown)."""
        async with self.session_lock:
            for session in list(self.sessions.values()):
                await session.stop()
            self.sessions.clear()
            logger.info("Closed all active game sessions")

    @property
    def session_count(self) -> int:
        """Number of registered sessions."""
        return len(self.sessions)

    def get_session_stats(self) -> Dict:
        """Summarise sessions: total, per-scene counts, and how many stream."""
        scene_tally = {}
        for session in self.sessions.values():
            scene_tally[session.current_scene] = scene_tally.get(session.current_scene, 0) + 1

        return {
            'total_sessions': len(self.sessions),
            'active_scenes': scene_tally,
            'streaming_sessions': sum(1 for s in self.sessions.values() if s.is_streaming),
        }
|
|
|
|
|
|
|
|
# Process-wide GameManager shared by status_handler, websocket_handler and
# the shutdown hook registered in init_app. Constructed at import time, so
# GameManager.__init__ (which builds the FrameGenerator and loads scene
# images) runs as soon as this module is imported.
game_manager = GameManager()
|
|
|
|
|
async def status_handler(request: web.Request) -> web.Response:
    """Return server identity, session statistics and the scene list as JSON.

    ``request`` is unused; the data comes from the module-level game_manager.
    """
    payload = {
        'product': 'MatrixGame WebSocket Server',
        'version': '1.0.0',
        'active_sessions': game_manager.get_session_stats(),
        'available_scenes': game_manager.valid_scenes,
    }
    return web.json_response(payload)
|
|
|
|
|
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """Serve one WebSocket client for the lifetime of its connection.

    Creates a GameSession, sends a 'welcome' message, answers 'ping'
    messages directly and queues every other JSON-object message on the
    session's action queue. The session is deleted when the connection
    ends, whatever the reason.
    """
    ws = web.WebSocketResponse(
        max_msg_size=1024 * 1024 * 10,
        timeout=60.0,
    )
    await ws.prepare(request)

    # Random per-connection id, also used as the session key.
    user_id = str(uuid.uuid4())

    # Prefer the socket peer address; fall back to the first X-Forwarded-For
    # entry. request.transport is None once the connection has been lost.
    transport = request.transport
    peername = transport.get_extra_info('peername') if transport is not None else None
    if peername is not None:
        client_ip = peername[0]
    else:
        client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()

    logger.info(f"Client {user_id} connecting from IP: {client_ip}")

    ws.user_id = user_id

    user_session = await game_manager.create_session(user_id, ws)

    try:
        # Sent inside the try so the session is still cleaned up if the
        # client disconnects before the welcome can be delivered.
        await ws.send_json({
            'action': 'welcome',
            'userId': user_id,
            'message': 'Welcome to the MatrixGame WebSocket server!',
            'scenes': game_manager.valid_scenes
        })

        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                # Reset per message so an error reply never reports the
                # action of an earlier message.
                data = None
                try:
                    data = json.loads(msg.data)
                    if not isinstance(data, dict):
                        raise ValueError('message must be a JSON object')

                    if data.get('action') == 'ping':
                        await ws.send_json({
                            'action': 'pong',
                            'requestId': data.get('requestId'),
                            'timestamp': time.time()
                        })
                    else:
                        await user_session.action_queue.put(data)

                except json.JSONDecodeError:
                    # Truncated: a single message may be up to max_msg_size.
                    logger.error(f"Invalid JSON from user {user_id}: {msg.data[:200]}")
                    await ws.send_json({
                        'error': 'Invalid JSON message',
                        'success': False
                    })
                except Exception as e:
                    logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
                    await ws.send_json({
                        'action': data.get('action') if isinstance(data, dict) else 'unknown',
                        'success': False,
                        'error': f'Error processing message: {str(e)}'
                    })

            elif msg.type == WSMsgType.ERROR:
                logger.error(f"WebSocket error for user {user_id}: {ws.exception()}")
                break
            elif msg.type == WSMsgType.CLOSE:
                break

    finally:
        await game_manager.delete_session(user_id)
        logger.info(f"Connection closed for user {user_id}")

    return ws
|
|
|
|
|
async def init_app() -> web.Application:
    """Build the aiohttp application.

    Registers the WebSocket and status routes, serves ``./client`` (next to
    this file) under ``/client``, and closes all game sessions on shutdown.
    """
    application = web.Application(client_max_size=10 * 1024 ** 2)

    async def _close_sessions_on_shutdown(app):
        logger.info("Shutting down server, closing all sessions...")
        await game_manager.close_all_sessions()

    application.on_shutdown.append(_close_sessions_on_shutdown)

    routes = application.router
    routes.add_get('/ws', websocket_handler)
    routes.add_get('/api/status', status_handler)

    client_dir = pathlib.Path(__file__).parent / 'client'
    routes.add_static('/client', path=client_dir)

    return application
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the server's command-line options (``--host``, ``--port``)."""
    cli = argparse.ArgumentParser(description="MatrixGame WebSocket Server")
    option_specs = (
        ("--host", str, "0.0.0.0", "Host IP to bind to"),
        ("--port", int, 8080, "Port to listen on"),
    )
    for flag, kind, fallback, help_text in option_specs:
        cli.add_argument(flag, type=kind, default=fallback, help=help_text)
    return cli.parse_args()
|
|
|
|
|
if __name__ == '__main__':
    args = parse_args()
    # web.run_app accepts a coroutine that returns the Application and awaits
    # it on the same event loop that then serves requests. This avoids
    # building the app on a separate asyncio.run() loop that is closed
    # before serving starts.
    web.run_app(init_app(), host=args.host, port=args.port)