jbilcke-hf committed
Commit 4315c88 · verified · 1 Parent(s): d9ee793

Upload 5 files

Files changed (5):
  1. requirements.txt +20 -0
  2. run_inference.sh +22 -0
  3. server.py +586 -0
  4. teacache_forward.py +353 -0
  5. tools/visualize.py +190 -0
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ diffusers==0.32.2
+ einops==0.8.1
+ flash_attn==2.7.4.post1
+ ftfy==6.3.1
+ imageio==2.34.0
+ numpy==1.24.4
+ opencv_python==4.9.0.80
+ opencv_python_headless==4.9.0.80
+ packaging==25.0
+ peft==0.14.0
+ Pillow==11.2.1
+ regex==2024.11.6
+ safetensors==0.5.3
+ torch==2.5.1
+ torchvision==0.20.1
+ torchaudio==2.5.1
+ transformers==4.47.1
+ aiohttp==3.9.3
+ jinja2==3.1.3
+ python-multipart==0.0.6
run_inference.sh ADDED
@@ -0,0 +1,22 @@
+ #!/usr/bin/env bash
+
+ # Set environment variable for CUDA memory allocation
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ export MODEL_ROOT="models/matrixgame"  # Replace with the actual path to your model directory
+ export DIT_PATH="$MODEL_ROOT/dit/"
+ export TEXTENC_PATH="$MODEL_ROOT"
+ export VAE_PATH="$MODEL_ROOT/vae/"
+ export MOUSE_ICON_PATH="$MODEL_ROOT/assets/mouse.png"
+ export IMAGE_PATH="initial_image/"  # Replace with the actual path to your initial image
+ export OUTPUT_PATH="./test"
+ export INFERENCE_STEPS=50
+ # Execute the inference script with these parameters
+ python inference_bench.py \
+     --dit_path "$DIT_PATH" \
+     --textenc_path "$TEXTENC_PATH" \
+     --vae_path "$VAE_PATH" \
+     --mouse_icon_path "$MOUSE_ICON_PATH" \
+     --image_path "$IMAGE_PATH" \
+     --output_path "$OUTPUT_PATH" \
+     --inference_steps "$INFERENCE_STEPS" \
+     --bfloat16
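For reference, here is a minimal Python sketch of the same launch, assuming inference_bench.py sits in the working directory; the values mirror the exports above. PYTORCH_CUDA_ALLOC_CONF goes into the child environment because it must be set before the process initializes CUDA, which is also why the script exports it up front.

# Sketch: the same launch driven from Python; paths mirror run_inference.sh
# and should be adjusted to your checkout.
import os
import subprocess

MODEL_ROOT = "models/matrixgame"  # replace with your model directory

# The allocator config must be in the environment before CUDA starts.
env = dict(os.environ, PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")
subprocess.run(
    [
        "python", "inference_bench.py",
        "--dit_path", f"{MODEL_ROOT}/dit/",
        "--textenc_path", MODEL_ROOT,
        "--vae_path", f"{MODEL_ROOT}/vae/",
        "--mouse_icon_path", f"{MODEL_ROOT}/assets/mouse.png",
        "--image_path", "initial_image/",
        "--output_path", "./test",
        "--inference_steps", "50",
        "--bfloat16",
    ],
    env=env,
    check=True,
)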
server.py ADDED
@@ -0,0 +1,586 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ """
+ MatrixGame WebSocket Gaming Server
+
+ This script implements a WebSocket server for the MatrixGame project,
+ allowing real-time streaming of game frames based on player inputs.
+ """
+
+ import asyncio
+ import json
+ import logging
+ import os
+ import pathlib
+ import time
+ import uuid
+ import io
+ import base64
+ from typing import Dict, List, Any, Optional
+ import argparse
+ import torch
+ import numpy as np
+ from PIL import Image
+ import cv2
+ from aiohttp import web, WSMsgType
+ from condtions import Bench_actions_76
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ class FrameGenerator:
+     """
+     Simplified frame generator for the game.
+     In production, this would use the MatrixGame model.
+     """
+     def __init__(self):
+         self.frame_width = 640
+         self.frame_height = 360
+         self.fps = 16
+         self.frame_count = 0
+         self.scenes = {
+             'forest': self._load_scene_frames('forest'),
+             'desert': self._load_scene_frames('desert'),
+             'beach': self._load_scene_frames('beach'),
+             'hills': self._load_scene_frames('hills'),
+             'river': self._load_scene_frames('river'),
+             'icy': self._load_scene_frames('icy'),
+             'mushroom': self._load_scene_frames('mushroom'),
+             'plain': self._load_scene_frames('plain')
+         }
+
+     def _load_scene_frames(self, scene_name):
+         """Load initial frames for a scene from the asset directory"""
+         frames = []
+         scene_dir = f"./GameWorldScore/asset/init_image/{scene_name}"
+
+         if os.path.exists(scene_dir):
+             image_files = sorted([f for f in os.listdir(scene_dir) if f.endswith('.png') or f.endswith('.jpg')])
+             for img_file in image_files:
+                 try:
+                     img_path = os.path.join(scene_dir, img_file)
+                     img = Image.open(img_path).convert("RGB")
+                     img = img.resize((self.frame_width, self.frame_height))
+                     frames.append(np.array(img))
+                 except Exception as e:
+                     logger.error(f"Error loading image {img_file}: {str(e)}")
+
+         # If no frames were loaded, create a default colored frame with text
+         if not frames:
+             frame = np.ones((self.frame_height, self.frame_width, 3), dtype=np.uint8) * 100
+             # Add scene name as text
+             cv2.putText(frame, f"Scene: {scene_name}", (50, 180),
+                         cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
+             frames.append(frame)
+
+         return frames
+
+     def get_next_frame(self, scene_name, keyboard_condition=None, mouse_condition=None):
+         """
+         Generate the next frame based on current conditions.
+
+         Args:
+             scene_name: Name of the current scene
+             keyboard_condition: Keyboard input state
+             mouse_condition: Mouse input state
+
+         Returns:
+             JPEG bytes of the frame
+         """
+         scene_frames = self.scenes.get(scene_name, self.scenes['forest'])
+
+         # In a real implementation, this would use the MatrixGame model to generate frames
+         # based on the keyboard_condition and mouse_condition
+
+         # For the demo, just cycle through the pre-loaded frames
+         frame_idx = self.frame_count % len(scene_frames)
+         frame = scene_frames[frame_idx].copy()
+         self.frame_count += 1
+
+         # If we have keyboard/mouse conditions, visualize them on the frame
+         if keyboard_condition:
+             # Visualize keyboard inputs (simple example)
+             keys = ["W", "S", "A", "D", "JUMP", "ATTACK"]
+             for i, key_pressed in enumerate(keyboard_condition[0]):
+                 color = (0, 255, 0) if key_pressed else (100, 100, 100)
+                 cv2.putText(frame, keys[i], (20 + i*100, 30),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
+
+         if mouse_condition:
+             # Visualize mouse movement (simple example)
+             mouse_x, mouse_y = mouse_condition[0]
+             # Scale mouse values for visualization
+             offset_x = int(mouse_x * 100)
+             offset_y = int(mouse_y * 100)
+             center_x, center_y = self.frame_width // 2, self.frame_height // 2
+             cv2.circle(frame, (center_x + offset_x, center_y - offset_y), 10, (255, 0, 0), -1)
+             cv2.putText(frame, f"Mouse: {mouse_x:.2f}, {mouse_y:.2f}",
+                         (self.frame_width - 250, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
+
+         # Convert frame to JPEG
+         success, buffer = cv2.imencode('.jpg', frame)
+         if not success:
+             logger.error("Failed to encode frame as JPEG")
+             # Return a blank frame
+             blank = np.ones((self.frame_height, self.frame_width, 3), dtype=np.uint8) * 100
+             success, buffer = cv2.imencode('.jpg', blank)
+
+         return buffer.tobytes()
+
+ class GameSession:
+     """
+     Represents a user's gaming session.
+     Each WebSocket connection gets its own session with separate queues.
+     """
+     def __init__(self, user_id: str, ws: web.WebSocketResponse, game_manager):
+         self.user_id = user_id
+         self.ws = ws
+         self.game_manager = game_manager
+
+         # Create action queue for this user session
+         self.action_queue = asyncio.Queue()
+
+         # Session creation time
+         self.created_at = time.time()
+         self.last_activity = time.time()
+
+         # Game state
+         self.current_scene = "forest"  # Default scene
+         self.is_streaming = False
+         self.stream_task = None
+
+         # Current input state
+         self.keyboard_state = [0, 0, 0, 0, 0, 0]  # forward, back, left, right, jump, attack
+         self.mouse_state = [0, 0]  # x, y
+
+         self.background_tasks = []
+
+     async def start(self):
+         """Start all the queue processors for this session"""
+         self.background_tasks = [
+             asyncio.create_task(self._process_action_queue()),
+         ]
+         logger.info(f"Started game session for user {self.user_id}")
+
+     async def stop(self):
+         """Stop all background tasks for this session"""
+         # Stop streaming if active
+         if self.is_streaming and self.stream_task:
+             self.is_streaming = False
+             self.stream_task.cancel()
+             try:
+                 await self.stream_task
+             except asyncio.CancelledError:
+                 pass
+
+         # Cancel other background tasks
+         for task in self.background_tasks:
+             task.cancel()
+
+         try:
+             # Wait for tasks to complete cancellation
+             await asyncio.gather(*self.background_tasks, return_exceptions=True)
+         except asyncio.CancelledError:
+             pass
+
+         logger.info(f"Stopped game session for user {self.user_id}")
+
+     async def _process_action_queue(self):
+         """Process game actions from the queue"""
+         while True:
+             data = await self.action_queue.get()
+             try:
+                 action_type = data.get('action')
+
+                 if action_type == 'start_stream':
+                     result = await self._handle_start_stream(data)
+                 elif action_type == 'stop_stream':
+                     result = await self._handle_stop_stream(data)
+                 elif action_type == 'keyboard_input':
+                     result = await self._handle_keyboard_input(data)
+                 elif action_type == 'mouse_input':
+                     result = await self._handle_mouse_input(data)
+                 elif action_type == 'change_scene':
+                     result = await self._handle_scene_change(data)
+                 else:
+                     result = {
+                         'action': action_type,
+                         'requestId': data.get('requestId'),
+                         'success': False,
+                         'error': f'Unknown action: {action_type}'
+                     }
+
+                 # Send response back to the client
+                 await self.ws.send_json(result)
+
+                 # Update last activity time
+                 self.last_activity = time.time()
+
+             except Exception as e:
+                 logger.error(f"Error processing action for user {self.user_id}: {str(e)}")
+                 try:
+                     await self.ws.send_json({
+                         'action': data.get('action'),
+                         'requestId': data.get('requestId', 'unknown'),
+                         'success': False,
+                         'error': f'Error processing action: {str(e)}'
+                     })
+                 except Exception as send_error:
+                     logger.error(f"Error sending error response: {send_error}")
+             finally:
+                 self.action_queue.task_done()
+
+     async def _handle_start_stream(self, data: Dict) -> Dict:
+         """Handle request to start streaming frames"""
+         if self.is_streaming:
+             return {
+                 'action': 'start_stream',
+                 'requestId': data.get('requestId'),
+                 'success': False,
+                 'error': 'Stream already active'
+             }
+
+         fps = data.get('fps', 16)
+         self.is_streaming = True
+         self.stream_task = asyncio.create_task(self._stream_frames(fps))
+
+         return {
+             'action': 'start_stream',
+             'requestId': data.get('requestId'),
+             'success': True,
+             'message': f'Streaming started at {fps} FPS'
+         }
+
+     async def _handle_stop_stream(self, data: Dict) -> Dict:
+         """Handle request to stop streaming frames"""
+         if not self.is_streaming:
+             return {
+                 'action': 'stop_stream',
+                 'requestId': data.get('requestId'),
+                 'success': False,
+                 'error': 'No active stream to stop'
+             }
+
+         self.is_streaming = False
+         if self.stream_task:
+             self.stream_task.cancel()
+             try:
+                 await self.stream_task
+             except asyncio.CancelledError:
+                 pass
+             self.stream_task = None
+
+         return {
+             'action': 'stop_stream',
+             'requestId': data.get('requestId'),
+             'success': True,
+             'message': 'Streaming stopped'
+         }
+
+     async def _handle_keyboard_input(self, data: Dict) -> Dict:
+         """Handle keyboard input from client"""
+         key = data.get('key', '')
+         pressed = data.get('pressed', False)
+
+         # Map key to keyboard state index
+         key_map = {
+             'w': 0, 'forward': 0,
+             's': 1, 'back': 1, 'backward': 1,
+             'a': 2, 'left': 2,
+             'd': 3, 'right': 3,
+             'space': 4, 'jump': 4,
+             'shift': 5, 'attack': 5, 'ctrl': 5
+         }
+
+         if key.lower() in key_map:
+             key_idx = key_map[key.lower()]
+             self.keyboard_state[key_idx] = 1 if pressed else 0
+
+         return {
+             'action': 'keyboard_input',
+             'requestId': data.get('requestId'),
+             'success': True,
+             'keyboardState': self.keyboard_state
+         }
+
+     async def _handle_mouse_input(self, data: Dict) -> Dict:
+         """Handle mouse movement/input from client"""
+         mouse_x = data.get('x', 0)
+         mouse_y = data.get('y', 0)
+
+         # Update mouse state; values are expected to be normalized between -1 and 1
+         self.mouse_state = [float(mouse_x), float(mouse_y)]
+
+         return {
+             'action': 'mouse_input',
+             'requestId': data.get('requestId'),
+             'success': True,
+             'mouseState': self.mouse_state
+         }
+
+     async def _handle_scene_change(self, data: Dict) -> Dict:
+         """Handle scene change requests"""
+         scene_name = data.get('scene', 'forest')
+         valid_scenes = ['forest', 'desert', 'beach', 'hills', 'river', 'icy', 'mushroom', 'plain']
+
+         if scene_name not in valid_scenes:
+             return {
+                 'action': 'change_scene',
+                 'requestId': data.get('requestId'),
+                 'success': False,
+                 'error': f'Invalid scene: {scene_name}. Valid scenes are: {", ".join(valid_scenes)}'
+             }
+
+         self.current_scene = scene_name
+
+         return {
+             'action': 'change_scene',
+             'requestId': data.get('requestId'),
+             'success': True,
+             'scene': scene_name
+         }
+
+     async def _stream_frames(self, fps: int):
+         """Stream frames to the client at the specified FPS"""
+         frame_interval = 1.0 / fps  # Time between frames in seconds
+
+         try:
+             while self.is_streaming:
+                 start_time = time.time()
+
+                 # Generate frame based on current keyboard and mouse state
+                 keyboard_condition = [self.keyboard_state]
+                 mouse_condition = [self.mouse_state]
+
+                 frame_bytes = self.game_manager.frame_generator.get_next_frame(
+                     self.current_scene, keyboard_condition, mouse_condition
+                 )
+
+                 # Encode as base64 for sending in JSON
+                 frame_base64 = base64.b64encode(frame_bytes).decode('utf-8')
+
+                 # Send frame to client
+                 await self.ws.send_json({
+                     'action': 'frame',
+                     'frameData': frame_base64,
+                     'timestamp': time.time()
+                 })
+
+                 # Calculate sleep time to maintain FPS
+                 elapsed = time.time() - start_time
+                 sleep_time = max(0, frame_interval - elapsed)
+                 await asyncio.sleep(sleep_time)
+
+         except asyncio.CancelledError:
+             logger.info(f"Frame streaming cancelled for user {self.user_id}")
+         except Exception as e:
+             logger.error(f"Error in frame streaming for user {self.user_id}: {str(e)}")
+             if self.ws.closed:
+                 logger.info(f"WebSocket closed for user {self.user_id}")
+                 return
+
+             # Notify client of error
+             try:
+                 await self.ws.send_json({
+                     'action': 'frame_error',
+                     'error': f'Streaming error: {str(e)}'
+                 })
+             except Exception:
+                 pass
+
+             # Stop streaming
+             self.is_streaming = False
+
+ class GameManager:
+     """
+     Manages all active gaming sessions and shared resources.
+     """
+     def __init__(self):
+         self.sessions = {}
+         self.session_lock = asyncio.Lock()
+
+         # Initialize frame generator
+         self.frame_generator = FrameGenerator()
+
+         # Load valid scenes from FrameGenerator
+         self.valid_scenes = list(self.frame_generator.scenes.keys())
+
+     async def create_session(self, user_id: str, ws: web.WebSocketResponse) -> GameSession:
+         """Create a new game session"""
+         async with self.session_lock:
+             # Create a new session for this user
+             session = GameSession(user_id, ws, self)
+             await session.start()
+             self.sessions[user_id] = session
+             return session
+
+     async def delete_session(self, user_id: str) -> None:
+         """Delete a game session and clean up resources"""
+         async with self.session_lock:
+             if user_id in self.sessions:
+                 session = self.sessions[user_id]
+                 await session.stop()
+                 del self.sessions[user_id]
+                 logger.info(f"Deleted game session for user {user_id}")
+
+     def get_session(self, user_id: str) -> Optional[GameSession]:
+         """Get a game session if it exists"""
+         return self.sessions.get(user_id)
+
+     async def close_all_sessions(self) -> None:
+         """Close all active sessions (used during shutdown)"""
+         async with self.session_lock:
+             for user_id, session in list(self.sessions.items()):
+                 await session.stop()
+             self.sessions.clear()
+             logger.info("Closed all active game sessions")
+
+     @property
+     def session_count(self) -> int:
+         """Get the number of active sessions"""
+         return len(self.sessions)
+
+     def get_session_stats(self) -> Dict:
+         """Get statistics about active sessions"""
+         stats = {
+             'total_sessions': len(self.sessions),
+             'active_scenes': {},
+             'streaming_sessions': 0
+         }
+
+         # Count sessions by scene and streaming status
+         for session in self.sessions.values():
+             scene = session.current_scene
+             stats['active_scenes'][scene] = stats['active_scenes'].get(scene, 0) + 1
+             if session.is_streaming:
+                 stats['streaming_sessions'] += 1
+
+         return stats
+
+ # Create global game manager
+ game_manager = GameManager()
+
+ async def status_handler(request: web.Request) -> web.Response:
+     """Handler for API status endpoint"""
+     # Get session statistics
+     session_stats = game_manager.get_session_stats()
+
+     return web.json_response({
+         'product': 'MatrixGame WebSocket Server',
+         'version': '1.0.0',
+         'active_sessions': session_stats,
+         'available_scenes': game_manager.valid_scenes
+     })
+
+ async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+     ws = web.WebSocketResponse(
+         max_msg_size=1024*1024*10,  # 10MB max message size
+         timeout=60.0
+     )
+
+     await ws.prepare(request)
+
+     # Generate a unique user ID for this connection
+     user_id = str(uuid.uuid4())
+
+     # Get client IP address
+     peername = request.transport.get_extra_info('peername')
+     if peername is not None:
+         client_ip = peername[0]
+     else:
+         client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()
+
+     logger.info(f"Client {user_id} connecting from IP: {client_ip}")
+
+     # Store the user ID on the websocket for easy access
+     ws.user_id = user_id
+
+     # Create a new session for this user
+     user_session = await game_manager.create_session(user_id, ws)
+
+     # Send initial welcome message
+     await ws.send_json({
+         'action': 'welcome',
+         'userId': user_id,
+         'message': 'Welcome to the MatrixGame WebSocket server!',
+         'scenes': game_manager.valid_scenes
+     })
+
+     try:
+         async for msg in ws:
+             if msg.type == WSMsgType.TEXT:
+                 try:
+                     data = json.loads(msg.data)
+                     action = data.get('action')
+
+                     if action == 'ping':
+                         # Respond to ping immediately
+                         await ws.send_json({
+                             'action': 'pong',
+                             'requestId': data.get('requestId'),
+                             'timestamp': time.time()
+                         })
+                     else:
+                         # Route game actions to the session's action queue
+                         await user_session.action_queue.put(data)
+
+                 except json.JSONDecodeError:
+                     logger.error(f"Invalid JSON from user {user_id}: {msg.data}")
+                     await ws.send_json({
+                         'error': 'Invalid JSON message',
+                         'success': False
+                     })
+                 except Exception as e:
+                     logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
+                     await ws.send_json({
+                         'action': data.get('action') if 'data' in locals() else 'unknown',
+                         'success': False,
+                         'error': f'Error processing message: {str(e)}'
+                     })
+
+             elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
+                 break
+
+     finally:
+         # Clean up the session
+         await game_manager.delete_session(user_id)
+         logger.info(f"Connection closed for user {user_id}")
+
+     return ws
+
+ async def init_app() -> web.Application:
+     app = web.Application(
+         client_max_size=1024**2*10  # 10MB max size
+     )
+
+     # Add cleanup logic
+     async def cleanup(app):
+         logger.info("Shutting down server, closing all sessions...")
+         await game_manager.close_all_sessions()
+
+     app.on_shutdown.append(cleanup)
+
+     # Add routes
+     app.router.add_get('/ws', websocket_handler)
+     app.router.add_get('/api/status', status_handler)
+
+     # Set up static file serving for the client demo
+     app.router.add_static('/client', path=pathlib.Path(__file__).parent / 'client')
+
+     return app
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="MatrixGame WebSocket Server")
+     parser.add_argument("--host", type=str, default="0.0.0.0", help="Host IP to bind to")
+     parser.add_argument("--port", type=int, default=8080, help="Port to listen on")
+     return parser.parse_args()
+
+ if __name__ == '__main__':
+     args = parse_args()
+     app = init_app()  # web.run_app accepts the coroutine returned by init_app
+     web.run_app(app, host=args.host, port=args.port)
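For reference, a minimal client sketch for the protocol implemented above, using the aiohttp version pinned in requirements.txt. The ws://localhost:8080/ws URL assumes the server's default --host and --port; the requestId values are arbitrary. The server answers each queued action with a JSON result and, once start_stream succeeds, streams 'frame' messages carrying base64-encoded JPEGs.

import asyncio
import json

import aiohttp


async def main():
    # Connect, read the 'welcome' message, start a stream, hold W,
    # consume a couple of seconds of frames, then stop.
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect("ws://localhost:8080/ws") as ws:
            welcome = await ws.receive_json()
            print("connected as", welcome["userId"], "scenes:", welcome["scenes"])

            await ws.send_json({"action": "start_stream", "fps": 16, "requestId": "r1"})
            await ws.send_json({"action": "keyboard_input", "key": "w", "pressed": True, "requestId": "r2"})

            frames = 0
            async for msg in ws:
                if msg.type != aiohttp.WSMsgType.TEXT:
                    break
                data = json.loads(msg.data)
                if data.get("action") == "frame":
                    frames += 1  # data["frameData"] is a base64-encoded JPEG
                    if frames >= 32:
                        break
            await ws.send_json({"action": "stop_stream", "requestId": "r3"})

asyncio.run(main())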
teacache_forward.py ADDED
@@ -0,0 +1,354 @@
+ # teacache
+ import torch
+ import numpy as np
+ from typing import Optional, Union, Dict, Any
+
+ from matrixgame.model_variants.matrixgame_dit_src.modulate_layers import modulate
+ from matrixgame.model_variants.matrixgame_dit_src.attenion import attention, get_cu_seqlens
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.utils import is_torch_version  # used by the gradient-checkpointing branches below
+
+
+ def teacache_forward(
+     self,
+     hidden_states: torch.Tensor,
+     timestep: torch.Tensor,  # Should be in range(0, 1000).
+     encoder_hidden_states: torch.Tensor = None,
+     encoder_attention_mask: torch.Tensor = None,  # Currently unused.
+     guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
+     mouse_condition=None,
+     keyboard_condition=None,
+     return_dict: bool = True,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+     x = hidden_states
+     t = timestep
+     text_states, text_states_2 = encoder_hidden_states
+     text_mask, text_mask_2 = encoder_attention_mask
+     out = {}
+     img = x
+     txt = text_states
+     _, _, ot, oh, ow = x.shape
+     freqs_cos, freqs_sin = self.get_rotary_pos_embed(ot, oh, ow)
+     tt, th, tw = (
+         ot // self.patch_size[0],
+         oh // self.patch_size[1],
+         ow // self.patch_size[2],
+     )
+
+     # Prepare modulation vectors.
+     vec = self.time_in(t)
+     if self.i2v_condition_type == "token_replace":
+         token_replace_t = torch.zeros_like(t)
+         token_replace_vec = self.time_in(token_replace_t)
+         first_frame_token_num = th * tw
+     else:
+         token_replace_vec = None
+         first_frame_token_num = None
+     # text modulation
+     # vec_2 = self.vector_in(text_states_2)
+     # vec = vec + vec_2
+     # if self.i2v_condition_type == "token_replace":
+     #     token_replace_vec = token_replace_vec + vec_2
+
+     # guidance modulation
+     if self.guidance_embed:
+         if guidance is None:
+             raise ValueError(
+                 "Didn't get guidance strength for guidance distilled model."
+             )
+
+         # our timestep_embedding is merged into guidance_in (TimestepEmbedder)
+         vec = vec + self.guidance_in(guidance)
+
+     # Embed image and text.
+     img = self.img_in(img)
+     if self.text_projection == "linear":
+         txt = self.txt_in(txt)
+     elif self.text_projection == "single_refiner":
+         txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
+     else:
+         raise NotImplementedError(
+             f"Unsupported text_projection: {self.text_projection}"
+         )
+
+     txt_seq_len = txt.shape[1]
+     img_seq_len = img.shape[1]
+
+     # Compute cu_seqlens and max_seqlen for flash attention
+     cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
+     cu_seqlens_kv = cu_seqlens_q
+     max_seqlen_q = img_seq_len + txt_seq_len
+     max_seqlen_kv = max_seqlen_q
+
+     freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
+
+     # teacache: decide whether this step's full transformer pass can be skipped
+     if self.enable_teacache:
+         inp = img.clone()
+         vec_ = vec.clone()
+         txt_ = txt.clone()
+         (
+             img_mod1_shift,
+             img_mod1_scale,
+             img_mod1_gate,
+             img_mod2_shift,
+             img_mod2_scale,
+             img_mod2_gate,
+         ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
+         normed_inp = self.double_blocks[0].img_norm1(inp)
+         modulated_inp = modulate(
+             normed_inp, shift=img_mod1_shift, scale=img_mod1_scale
+         )
+         if self.cnt == 0 or self.cnt == self.num_steps - 1:
+             should_calc = True
+             self.accumulated_rel_l1_distance = 0
+         else:
+             coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
+             # coefficients = [-296.53, 191.67, -39.037, 3.705, -0.0383]
+             rescale_func = np.poly1d(coefficients)
+             self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+             if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                 should_calc = False
+             else:
+                 should_calc = True
+                 self.accumulated_rel_l1_distance = 0
+         self.previous_modulated_input = modulated_inp
+         self.cnt += 1
+         if self.cnt == self.num_steps:
+             self.cnt = 0
+
+     if self.enable_teacache:
+         if not should_calc:
+             img += self.previous_residual
+         else:
+             ori_img = img.clone()
+             # --------------------- Pass through DiT blocks ------------------------
+             for _, block in enumerate(self.double_blocks):
+                 if torch.is_grad_enabled() and self.gradient_checkpointing:
+                     def create_custom_forward(module):
+                         def custom_forward(*inputs):
+                             return module(*inputs)
+
+                         return custom_forward
+                     ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                     image_kwargs: Dict[str, Any] = {"tt": hidden_states.shape[2] // self.patch_size[0],
+                                                     "th": hidden_states.shape[3] // self.patch_size[1],
+                                                     "tw": hidden_states.shape[4] // self.patch_size[2]}
+                     img, txt = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(block),
+                         img,
+                         txt,
+                         vec,
+                         cu_seqlens_q,
+                         cu_seqlens_kv,
+                         max_seqlen_q,
+                         max_seqlen_kv,
+                         freqs_cis,
+                         image_kwargs,
+                         mouse_condition,
+                         keyboard_condition,
+                         self.i2v_condition_type,
+                         token_replace_vec,
+                         first_frame_token_num,
+                         **ckpt_kwargs,
+                     )
+                 else:
+                     image_kwargs: Dict[str, Any] = {"tt": hidden_states.shape[2] // self.patch_size[0],
+                                                     "th": hidden_states.shape[3] // self.patch_size[1],
+                                                     "tw": hidden_states.shape[4] // self.patch_size[2]}
+                     double_block_args = [
+                         img,
+                         txt,
+                         vec,
+                         cu_seqlens_q,
+                         cu_seqlens_kv,
+                         max_seqlen_q,
+                         max_seqlen_kv,
+                         freqs_cis,
+                         image_kwargs,
+                         mouse_condition,
+                         keyboard_condition,
+                         self.i2v_condition_type,
+                         token_replace_vec,
+                         first_frame_token_num,
+                     ]
+
+                     img, txt = block(*double_block_args)
+
+             # Merge txt and img to pass through single stream blocks.
+             x = torch.cat((img, txt), 1)
+             if len(self.single_blocks) > 0:
+                 for _, block in enumerate(self.single_blocks):
+                     if torch.is_grad_enabled() and self.gradient_checkpointing:
+                         def create_custom_forward(module):
+                             def custom_forward(*inputs):
+                                 return module(*inputs)
+
+                             return custom_forward
+                         ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                         image_kwargs: Dict[str, Any] = {"tt": hidden_states.shape[2] // self.patch_size[0],
+                                                         "th": hidden_states.shape[3] // self.patch_size[1],
+                                                         "tw": hidden_states.shape[4] // self.patch_size[2]}
+                         x = torch.utils.checkpoint.checkpoint(
+                             create_custom_forward(block),
+                             x,
+                             vec,
+                             txt_seq_len,
+                             cu_seqlens_q,
+                             cu_seqlens_kv,
+                             max_seqlen_q,
+                             max_seqlen_kv,
+                             (freqs_cos, freqs_sin),
+                             image_kwargs,
+                             mouse_condition,
+                             keyboard_condition,
+                             self.i2v_condition_type,
+                             token_replace_vec,
+                             first_frame_token_num,
+                             **ckpt_kwargs,
+                         )
+                     else:
+                         image_kwargs: Dict[str, Any] = {"tt": hidden_states.shape[2] // self.patch_size[0],
+                                                         "th": hidden_states.shape[3] // self.patch_size[1],
+                                                         "tw": hidden_states.shape[4] // self.patch_size[2]}
+                         single_block_args = [
+                             x,
+                             vec,
+                             txt_seq_len,
+                             cu_seqlens_q,
+                             cu_seqlens_kv,
+                             max_seqlen_q,
+                             max_seqlen_kv,
+                             (freqs_cos, freqs_sin),
+                             image_kwargs,
+                             mouse_condition,
+                             keyboard_condition,
+                             self.i2v_condition_type,
+                             token_replace_vec,
+                             first_frame_token_num,
+                         ]
+
+                         x = block(*single_block_args)
+
+             img = x[:, :img_seq_len, ...]
+             self.previous_residual = img - ori_img
+     else:
+         # --------------------- Pass through DiT blocks ------------------------
+         for _, block in enumerate(self.double_blocks):
+             if torch.is_grad_enabled() and self.gradient_checkpointing:
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 image_kwargs: Dict[str, Any] = {"tt": hidden_states.shape[2] // self.patch_size[0],
+                                                 "th": hidden_states.shape[3] // self.patch_size[1],
+                                                 "tw": hidden_states.shape[4] // self.patch_size[2]}
+                 img, txt = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     img,
+                     txt,
+                     vec,
+                     cu_seqlens_q,
+                     cu_seqlens_kv,
+                     max_seqlen_q,
+                     max_seqlen_kv,
+                     freqs_cis,
+                     image_kwargs,
+                     mouse_condition,
+                     keyboard_condition,
+                     self.i2v_condition_type,
+                     token_replace_vec,
+                     first_frame_token_num,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 image_kwargs: Dict[str, Any] = {"tt": hidden_states.shape[2] // self.patch_size[0],
+                                                 "th": hidden_states.shape[3] // self.patch_size[1],
+                                                 "tw": hidden_states.shape[4] // self.patch_size[2]}
+                 double_block_args = [
+                     img,
+                     txt,
+                     vec,
+                     cu_seqlens_q,
+                     cu_seqlens_kv,
+                     max_seqlen_q,
+                     max_seqlen_kv,
+                     freqs_cis,
+                     image_kwargs,
+                     mouse_condition,
+                     keyboard_condition,
+                     self.i2v_condition_type,
+                     token_replace_vec,
+                     first_frame_token_num,
+                 ]
+
+                 img, txt = block(*double_block_args)
+
+         # Merge txt and img to pass through single stream blocks.
+         x = torch.cat((img, txt), 1)
+         if len(self.single_blocks) > 0:
+             for _, block in enumerate(self.single_blocks):
+                 if torch.is_grad_enabled() and self.gradient_checkpointing:
+                     def create_custom_forward(module):
+                         def custom_forward(*inputs):
+                             return module(*inputs)
+
+                         return custom_forward
+                     ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                     image_kwargs: Dict[str, Any] = {"tt": hidden_states.shape[2] // self.patch_size[0],
+                                                     "th": hidden_states.shape[3] // self.patch_size[1],
+                                                     "tw": hidden_states.shape[4] // self.patch_size[2]}
+                     x = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(block),
+                         x,
+                         vec,
+                         txt_seq_len,
+                         cu_seqlens_q,
+                         cu_seqlens_kv,
+                         max_seqlen_q,
+                         max_seqlen_kv,
+                         (freqs_cos, freqs_sin),
+                         image_kwargs,
+                         mouse_condition,
+                         keyboard_condition,
+                         self.i2v_condition_type,
+                         token_replace_vec,
+                         first_frame_token_num,
+                         **ckpt_kwargs,
+                     )
+                 else:
+                     image_kwargs: Dict[str, Any] = {"tt": hidden_states.shape[2] // self.patch_size[0],
+                                                     "th": hidden_states.shape[3] // self.patch_size[1],
+                                                     "tw": hidden_states.shape[4] // self.patch_size[2]}
+                     single_block_args = [
+                         x,
+                         vec,
+                         txt_seq_len,
+                         cu_seqlens_q,
+                         cu_seqlens_kv,
+                         max_seqlen_q,
+                         max_seqlen_kv,
+                         (freqs_cos, freqs_sin),
+                         image_kwargs,
+                         mouse_condition,
+                         keyboard_condition,
+                         self.i2v_condition_type,
+                         token_replace_vec,
+                         first_frame_token_num,
+                     ]
+
+                     x = block(*single_block_args)
+
+         img = x[:, :img_seq_len, ...]
+
+     # ---------------------------- Final layer ------------------------------
+     img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+
+     img = self.unpatchify(img, tt, th, tw)
+     if return_dict:
+         out["x"] = img
+         return out
+     return (img,)
+
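The caching decision above is interleaved with the DiT forward pass; isolated, it is easier to follow. The sketch below is an illustrative class, not part of the commit; names follow teacache_forward. It accumulates a polynomial-rescaled relative L1 change of the first block's modulated input and triggers a full recompute only when the accumulator crosses rel_l1_thresh; otherwise the forward pass reuses self.previous_residual, the img delta cached on the last full pass.

import numpy as np
import torch

# Rescaling polynomial for the raw relative L1 change, copied from above.
COEFFICIENTS = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
RESCALE = np.poly1d(COEFFICIENTS)


class TeaCacheGate:
    """Stand-alone version of the should_calc bookkeeping in teacache_forward."""

    def __init__(self, num_steps: int, rel_l1_thresh: float):
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh
        self.cnt = 0
        self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input = None

    def should_calc(self, modulated_inp: torch.Tensor) -> bool:
        # The first and last denoising steps are always computed in full.
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            decide = True
            self.accumulated_rel_l1_distance = 0.0
        else:
            rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean()
                      / self.previous_modulated_input.abs().mean()).item()
            self.accumulated_rel_l1_distance += RESCALE(rel_l1)
            decide = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
            if decide:
                self.accumulated_rel_l1_distance = 0.0
        self.previous_modulated_input = modulated_inp
        self.cnt = (self.cnt + 1) % self.num_steps
        return decide

In teacache_forward the same counters live on the transformer module itself (self.cnt, self.num_steps, self.rel_l1_thresh), which whatever code enables TeaCache is expected to initialize.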
tools/visualize.py ADDED
@@ -0,0 +1,183 @@
+ import cv2
+ import numpy as np
+ import os
+ import subprocess
+ from diffusers.utils import export_to_video
+
+ def parse_config(config):
+     """
+     Generate key data and mouse data from an action configuration.
+     - config: the (keyboard, mouse) configuration from list_actions[i]
+     - Returns: key_data and mouse_data
+     """
+     key_data = {}
+     mouse_data = {}
+
+     # Parse the frame range of the Space key
+     space_frames = set()
+     key, mouse = config
+
+     for i in range(len(mouse)):
+
+         if len(key[i]) == 7:
+             w, s, a, d, space, attack, _ = key[i]
+         else:
+             w, s, a, d, space, attack = key[i]
+
+         mouse_y, mouse_x = mouse[i]
+         mouse_y = -1 * mouse_y
+
+         # Key states
+         key_data[i] = {
+             "W": bool(w),
+             "A": bool(a),
+             "S": bool(s),
+             "D": bool(d),
+             "Space": bool(space),
+             "Attack": bool(attack),
+         }
+         # Mouse position
+         if i == 0:
+             mouse_data[i] = (320, 176)  # default initial position
+         else:
+             global_scale_factor = 0.2
+             mouse_scale_x = 15 * global_scale_factor
+             mouse_scale_y = 15 * 4 * global_scale_factor
+             mouse_data[i] = (
+                 mouse_data[i-1][0] + mouse_x * mouse_scale_x,  # accumulated x coordinate
+                 mouse_data[i-1][1] + mouse_y * mouse_scale_y,  # accumulated y coordinate
+             )
+
+     return key_data, mouse_data
+
+
+ # Draw a rounded rectangle
+ def draw_rounded_rectangle(image, top_left, bottom_right, color, radius=10, alpha=0.5):
+     overlay = image.copy()
+     x1, y1 = top_left
+     x2, y2 = bottom_right
+
+     cv2.rectangle(overlay, (x1 + radius, y1), (x2 - radius, y2), color, -1)
+     cv2.rectangle(overlay, (x1, y1 + radius), (x2, y2 - radius), color, -1)
+
+     cv2.ellipse(overlay, (x1 + radius, y1 + radius), (radius, radius), 180, 0, 90, color, -1)
+     cv2.ellipse(overlay, (x2 - radius, y1 + radius), (radius, radius), 270, 0, 90, color, -1)
+     cv2.ellipse(overlay, (x1 + radius, y2 - radius), (radius, radius), 90, 0, 90, color, -1)
+     cv2.ellipse(overlay, (x2 - radius, y2 - radius), (radius, radius), 0, 0, 90, color, -1)
+
+     cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
+
+ # Draw key indicators on a frame
+ def draw_keys_on_frame(frame, keys, key_size=(80, 50), spacing=20, bottom_margin=30):
+     h, w, _ = frame.shape
+     horizon_shift = 90
+     vertical_shift = -20
+     horizon_shift_all = 50
+     key_positions = {
+         "W": (w // 2 - key_size[0] // 2 - horizon_shift - horizon_shift_all + spacing * 2, h - bottom_margin - key_size[1] * 2 + vertical_shift - 20),
+         "A": (w // 2 - key_size[0] * 2 + 5 - horizon_shift - horizon_shift_all + spacing * 2, h - bottom_margin - key_size[1] + vertical_shift),
+         "S": (w // 2 - key_size[0] // 2 - horizon_shift - horizon_shift_all + spacing * 2, h - bottom_margin - key_size[1] + vertical_shift),
+         "D": (w // 2 + key_size[0] - 5 - horizon_shift - horizon_shift_all + spacing * 2, h - bottom_margin - key_size[1] + vertical_shift),
+         "Space": (w // 2 + key_size[0] * 2 + spacing * 4 - horizon_shift - horizon_shift_all, h - bottom_margin - key_size[1] + vertical_shift),
+         "Attack": (w // 2 + key_size[0] * 3 + spacing * 9 - horizon_shift - horizon_shift_all, h - bottom_margin - key_size[1] + vertical_shift),
+     }
+
+     for key, (x, y) in key_positions.items():
+         is_pressed = keys.get(key, False)
+         top_left = (x, y)
+         if key in ["Space", "Attack"]:
+             bottom_right = (x + key_size[0] + 40, y + key_size[1])
+         else:
+             bottom_right = (x + key_size[0], y + key_size[1])
+
+         color = (0, 255, 0) if is_pressed else (200, 200, 200)
+         alpha = 0.8 if is_pressed else 0.5
+
+         draw_rounded_rectangle(frame, top_left, bottom_right, color, radius=10, alpha=alpha)
+
+         text_size = cv2.getTextSize(key, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)[0]
+         if key in ["Space", "Attack"]:
+             text_x = x + (key_size[0] + 40 - text_size[0]) // 2
+         else:
+             text_x = x + (key_size[0] - text_size[0]) // 2
+         text_y = y + (key_size[1] + text_size[1]) // 2
+         cv2.putText(frame, key, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
+
+ # Overlay the mouse icon on a frame
+ def overlay_icon(frame, icon, position, scale=1.0, rotation=0):
+     x, y = position
+     h, w, _ = icon.shape
+
+     # Scale the icon
+     scaled_width = int(w * scale)
+     scaled_height = int(h * scale)
+     icon_resized = cv2.resize(icon, (scaled_width, scaled_height), interpolation=cv2.INTER_AREA)
+
+     # Rotate the icon
+     center = (scaled_width // 2, scaled_height // 2)
+     rotation_matrix = cv2.getRotationMatrix2D(center, rotation, 1.0)
+     icon_rotated = cv2.warpAffine(icon_resized, rotation_matrix, (scaled_width, scaled_height), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0, 0))
+
+     h, w, _ = icon_rotated.shape
+     frame_h, frame_w, _ = frame.shape
+
+     # Compute the drawing region, clipped to the frame bounds
+     top_left_x = max(0, int(x - w // 2))
+     top_left_y = max(0, int(y - h // 2))
+     bottom_right_x = min(frame_w, int(x + w // 2))
+     bottom_right_y = min(frame_h, int(y + h // 2))
+
+     icon_x_start = max(0, int(-x + w // 2))
+     icon_y_start = max(0, int(-y + h // 2))
+     icon_x_end = icon_x_start + (bottom_right_x - top_left_x)
+     icon_y_end = icon_y_start + (bottom_right_y - top_left_y)
+
+     # Extract the icon region and split alpha from RGB
+     icon_region = icon_rotated[icon_y_start:icon_y_end, icon_x_start:icon_x_end]
+     alpha = icon_region[:, :, 3] / 255.0
+     icon_rgb = icon_region[:, :, :3]
+
+     # Extract the corresponding frame region
+     frame_region = frame[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
+
+     # Alpha-blend the icon over the frame
+     for c in range(3):
+         frame_region[:, :, c] = (1 - alpha) * frame_region[:, :, c] + alpha * icon_rgb[:, :, c]
+
+     # Write the blended region back into the frame
+     frame[top_left_y:bottom_right_y, top_left_x:bottom_right_x] = frame_region
+
+
+ # Process a video: overlay key and mouse visualizations on every frame, then export
+ def process_video(input_video, output_video, config, mouse_icon_path, mouse_scale=2.0, mouse_rotation=0, fps=16):
+     key_data, mouse_data = parse_config(config)
+     frame_width = input_video[0].shape[1]
+     frame_height = input_video[0].shape[0]
+     frame_count = len(input_video)
+
+     mouse_icon = cv2.imread(mouse_icon_path, cv2.IMREAD_UNCHANGED)
+     out_video = []
+     frame_idx = 0
+     for frame in input_video:
+         keys = key_data.get(frame_idx, {"W": False, "A": False, "S": False, "D": False, "Space": False, "Attack": False})
+         raw_mouse_pos = mouse_data.get(frame_idx, (frame_width // 2 // 2, frame_height // 2 // 2))  # fallback also uses the half-resolution center
+         mouse_position = (int(raw_mouse_pos[0] * 2), int(raw_mouse_pos[1] * 2))
+         draw_keys_on_frame(frame, keys, key_size=(75, 75), spacing=10, bottom_margin=20)
+         overlay_icon(frame, mouse_icon, mouse_position, scale=mouse_scale, rotation=mouse_rotation)
+         out_video.append(frame / 255)
+         frame_idx += 1
+         print(f"Processing frame {frame_idx}/{frame_count}", end="\r")
+     export_to_video(out_video, output_video, fps=fps)
+     print("\nProcessing complete!")
+
+ # Save a video without overlays
+ def save_video(input_video, output_video, fps=16):
+     frame_count = len(input_video)
+     out_video = []
+     frame_idx = 0
+     for frame in input_video:
+         out_video.append(frame / 255)
+         frame_idx += 1
+         print(f"Processing frame {frame_idx}/{frame_count}", end="\r")
+     export_to_video(out_video, output_video, fps=fps)
+     print("\nProcessing complete!")
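A hedged usage sketch for process_video: frames are uint8 RGB arrays, and config is the (keyboard, mouse) pair that parse_config unpacks. The icon path is an assumption and must point to an RGBA image, since overlay_icon reads an alpha channel.

import numpy as np

from tools.visualize import process_video

num_frames = 16
frames = [np.zeros((360, 640, 3), dtype=np.uint8) for _ in range(num_frames)]
keyboard = [[1, 0, 0, 0, 0, 0]] * num_frames   # hold W on every frame
mouse = [[0.0, 0.1]] * num_frames              # (y, x) deltas: slow pan right

process_video(
    input_video=frames,
    output_video="overlay.mp4",
    config=(keyboard, mouse),
    mouse_icon_path="assets/mouse.png",  # assumed RGBA icon
    mouse_scale=1.0,
    fps=16,
)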