jbilcke-hf HF Staff commited on
Commit
d14e7aa
·
1 Parent(s): 76c6112
example/Dockerfile DELETED
@@ -1,59 +0,0 @@
1
- FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
2
-
3
- ARG DEBIAN_FRONTEND=noninteractive
4
-
5
- ENV PYTHONUNBUFFERED=1
6
-
7
- RUN apt-get update && apt-get install --no-install-recommends -y \
8
- build-essential \
9
- python3.11 \
10
- python3-pip \
11
- python3-dev \
12
- git \
13
- curl \
14
- ffmpeg \
15
- libglib2.0-0 \
16
- libsm6 \
17
- libxrender1 \
18
- libxext6 \
19
- ninja-build \
20
- && apt-get clean && rm -rf /var/lib/apt/lists/*
21
-
22
- WORKDIR /code
23
-
24
- COPY ./requirements.txt /code/requirements.txt
25
-
26
- # Set up a new user named "user" with user ID 1000
27
- RUN useradd -m -u 1000 user
28
- # Switch to the "user" user
29
- USER user
30
- # Set home to the user's home directory
31
- ENV HOME=/home/user \
32
- PATH=/home/user/.local/bin:$PATH
33
-
34
- # Set Python path and environment variables
35
- ENV PYTHONPATH=$HOME/app \
36
- PYTHONUNBUFFERED=1 \
37
- DATA_ROOT=/tmp/data
38
-
39
- RUN echo "Installing requirements.txt"
40
- RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
41
-
42
- # Install NVIDIA Apex with CUDA and C++ extensions
43
- RUN cd $HOME && \
44
- git clone https://github.com/NVIDIA/apex && \
45
- cd apex && \
46
- NVCC_APPEND_FLAGS="--threads 4" pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--parallel" --global-option="8" ./
47
-
48
- WORKDIR $HOME/app
49
-
50
- # Copy all files and set proper ownership
51
- COPY --chown=user . $HOME/app
52
-
53
- # Expose the port that server.py uses (8080)
54
- EXPOSE 8080
55
-
56
- ENV PORT 8080
57
-
58
- # Run the HF space launcher script which sets up the correct paths
59
- CMD ["python3", "run_hf_space.py"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
example/client.js DELETED
@@ -1,603 +0,0 @@
1
- // MatrixGame WebSocket Client
2
-
3
- // WebSocket connection
4
- let socket = null;
5
- let userId = null;
6
- let isStreaming = false;
7
- let lastFrameTime = 0;
8
- let frameCount = 0;
9
- let fpsUpdateInterval = null;
10
-
11
- // DOM Elements
12
- const connectBtn = document.getElementById('connect-btn');
13
- const startStreamBtn = document.getElementById('start-stream-btn');
14
- const stopStreamBtn = document.getElementById('stop-stream-btn');
15
- const sceneSelect = document.getElementById('scene-select');
16
- const gameCanvas = document.getElementById('game-canvas');
17
- const connectionLog = document.getElementById('connection-log');
18
- const mousePosition = document.getElementById('mouse-position');
19
- const fpsCounter = document.getElementById('fps-counter');
20
- const mouseTrackingArea = document.getElementById('mouse-tracking-area');
21
-
22
- // Pointer Lock API support check
23
- const pointerLockSupported = 'pointerLockElement' in document ||
24
- 'mozPointerLockElement' in document ||
25
- 'webkitPointerLockElement' in document;
26
-
27
- // Keyboard DOM elements
28
- const keyElements = {
29
- 'w': document.getElementById('key-w'),
30
- 'a': document.getElementById('key-a'),
31
- 's': document.getElementById('key-s'),
32
- 'd': document.getElementById('key-d'),
33
- 'space': document.getElementById('key-space'),
34
- 'shift': document.getElementById('key-shift')
35
- };
36
-
37
- // Key mapping to action names
38
- const keyToAction = {
39
- 'w': 'forward',
40
- 'arrowup': 'forward',
41
- 'a': 'left',
42
- 'arrowleft': 'left',
43
- 's': 'back',
44
- 'arrowdown': 'back',
45
- 'd': 'right',
46
- 'arrowright': 'right',
47
- ' ': 'jump',
48
- 'shift': 'attack'
49
- };
50
-
51
- // Key state tracking
52
- const keyState = {
53
- 'forward': false,
54
- 'back': false,
55
- 'left': false,
56
- 'right': false,
57
- 'jump': false,
58
- 'attack': false
59
- };
60
-
61
- // Mouse state
62
- const mouseState = {
63
- x: 0,
64
- y: 0,
65
- captured: false
66
- };
67
-
68
- // Test server connectivity before establishing WebSocket
69
- async function testServerConnectivity() {
70
- try {
71
- // Get base path by extracting path from the script tag's src attribute
72
- let basePath = '';
73
- const scriptTags = document.getElementsByTagName('script');
74
- for (const script of scriptTags) {
75
- if (script.src.includes('client.js')) {
76
- const url = new URL(script.src);
77
- basePath = url.pathname.replace('/assets/client.js', '');
78
- break;
79
- }
80
- }
81
-
82
- // Try to fetch the debug endpoint to see if the server is accessible
83
- const response = await fetch(`${window.location.protocol}//${window.location.host}${basePath}/api/debug`);
84
- if (!response.ok) {
85
- throw new Error(`Server returned ${response.status}`);
86
- }
87
-
88
- const debugInfo = await response.json();
89
- logMessage(`Server connection test successful! Server time: ${new Date(debugInfo.server_time * 1000).toLocaleTimeString()}`);
90
-
91
- // Log available routes from server
92
- if (debugInfo.all_routes && debugInfo.all_routes.length > 0) {
93
- logMessage(`Available routes: ${debugInfo.all_routes.join(', ')}`);
94
- }
95
-
96
- // Return the debug info for connection setup
97
- return debugInfo;
98
- } catch (error) {
99
- logMessage(`Server connection test failed: ${error.message}`);
100
- return null;
101
- }
102
- }
103
-
104
- // Connect to WebSocket server
105
- async function connectWebSocket() {
106
- // First test connectivity to the server
107
- logMessage('Testing server connectivity...');
108
- const debugInfo = await testServerConnectivity();
109
-
110
- // Use secure WebSocket (wss://) if the page is loaded over HTTPS
111
- const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
112
-
113
- // Get base path by extracting path from the script tag's src attribute
114
- let basePath = '';
115
- if (debugInfo && debugInfo.base_path) {
116
- // Use base path from server if available
117
- basePath = debugInfo.base_path;
118
- logMessage(`Using server-provided base path: ${basePath}`);
119
- } else {
120
- const scriptTags = document.getElementsByTagName('script');
121
- for (const script of scriptTags) {
122
- if (script.src.includes('client.js')) {
123
- const url = new URL(script.src);
124
- basePath = url.pathname.replace('/assets/client.js', '');
125
- break;
126
- }
127
- }
128
- }
129
-
130
- // Try both with and without base path for WebSocket connection
131
- let serverUrl = `${protocol}//${window.location.hostname}${window.location.port ? ':' + window.location.port : ''}${basePath}/ws`;
132
- logMessage(`Attempting to connect to WebSocket at ${serverUrl}...`);
133
-
134
- // For Hugging Face Spaces, try the direct /ws path if the base path doesn't work
135
- const fallbackUrl = `${protocol}//${window.location.hostname}${window.location.port ? ':' + window.location.port : ''}/ws`;
136
-
137
- try {
138
- socket = new WebSocket(serverUrl);
139
- setupWebSocketHandlers();
140
-
141
- // Set a timeout to try the fallback URL if the first one doesn't connect
142
- setTimeout(() => {
143
- if (socket.readyState !== WebSocket.OPEN && socket.readyState !== WebSocket.CONNECTING) {
144
- logMessage(`Connection to ${serverUrl} failed. Trying fallback URL: ${fallbackUrl}`);
145
- socket = new WebSocket(fallbackUrl);
146
- setupWebSocketHandlers();
147
- }
148
- }, 3000);
149
- } catch (error) {
150
- logMessage(`Error connecting to WebSocket: ${error.message}`);
151
- resetUI();
152
- }
153
- }
154
-
155
- // Set up WebSocket event handlers
156
- function setupWebSocketHandlers() {
157
- socket.onopen = () => {
158
- logMessage('WebSocket connection established');
159
- connectBtn.textContent = 'Disconnect';
160
- startStreamBtn.disabled = false;
161
- sceneSelect.disabled = false;
162
- };
163
-
164
- socket.onmessage = (event) => {
165
- const message = JSON.parse(event.data);
166
-
167
- switch (message.action) {
168
- case 'welcome':
169
- userId = message.userId;
170
- logMessage(`Connected with user ID: ${userId}`);
171
-
172
- // Update scene options if server provides them
173
- if (message.scenes && Array.isArray(message.scenes)) {
174
- sceneSelect.innerHTML = '';
175
- message.scenes.forEach(scene => {
176
- const option = document.createElement('option');
177
- option.value = scene;
178
- option.textContent = scene.charAt(0).toUpperCase() + scene.slice(1);
179
- sceneSelect.appendChild(option);
180
- });
181
- }
182
- break;
183
-
184
- case 'frame':
185
- // Process incoming frame
186
- processFrame(message);
187
- break;
188
-
189
- case 'start_stream':
190
- if (message.success) {
191
- isStreaming = true;
192
- startStreamBtn.disabled = true;
193
- stopStreamBtn.disabled = false;
194
- logMessage(`Streaming started: ${message.message}`);
195
-
196
- // Start FPS counter
197
- startFpsCounter();
198
- } else {
199
- logMessage(`Error starting stream: ${message.error}`);
200
- }
201
- break;
202
-
203
- case 'stop_stream':
204
- if (message.success) {
205
- isStreaming = false;
206
- startStreamBtn.disabled = false;
207
- stopStreamBtn.disabled = true;
208
- logMessage('Streaming stopped');
209
-
210
- // Stop FPS counter
211
- stopFpsCounter();
212
- } else {
213
- logMessage(`Error stopping stream: ${message.error}`);
214
- }
215
- break;
216
-
217
- case 'pong':
218
- // Server responded to ping
219
- break;
220
-
221
- case 'change_scene':
222
- if (message.success) {
223
- logMessage(`Scene changed to ${message.scene}`);
224
- } else {
225
- logMessage(`Error changing scene: ${message.error}`);
226
- }
227
- break;
228
-
229
- default:
230
- logMessage(`Received message: ${JSON.stringify(message)}`);
231
- }
232
- };
233
-
234
- socket.onclose = (event) => {
235
- logMessage(`WebSocket connection closed (code: ${event.code}, reason: ${event.reason || 'none given'})`);
236
- resetUI();
237
- };
238
-
239
- socket.onerror = (error) => {
240
- logMessage(`WebSocket error. This is often caused by CORS issues or the server being inaccessible.`);
241
- console.error('WebSocket error:', error);
242
- resetUI();
243
- };
244
- }
245
-
246
- // Disconnect from WebSocket server
247
- function disconnectWebSocket() {
248
- if (socket && socket.readyState === WebSocket.OPEN) {
249
- // Stop streaming if active
250
- if (isStreaming) {
251
- sendStopStream();
252
- }
253
-
254
- // Close the socket
255
- socket.close();
256
- logMessage('Disconnected from server');
257
- }
258
- }
259
-
260
- // Start streaming frames
261
- function sendStartStream() {
262
- if (socket && socket.readyState === WebSocket.OPEN) {
263
- socket.send(JSON.stringify({
264
- action: 'start_stream',
265
- requestId: generateRequestId(),
266
- fps: 16 // Default FPS
267
- }));
268
- }
269
- }
270
-
271
- // Stop streaming frames
272
- function sendStopStream() {
273
- if (socket && socket.readyState === WebSocket.OPEN) {
274
- socket.send(JSON.stringify({
275
- action: 'stop_stream',
276
- requestId: generateRequestId()
277
- }));
278
- }
279
- }
280
-
281
- // Send keyboard input to server
282
- function sendKeyboardInput(key, pressed) {
283
- if (socket && socket.readyState === WebSocket.OPEN) {
284
- socket.send(JSON.stringify({
285
- action: 'keyboard_input',
286
- requestId: generateRequestId(),
287
- key: key,
288
- pressed: pressed
289
- }));
290
- }
291
- }
292
-
293
- // Send mouse input to server
294
- function sendMouseInput(x, y) {
295
- if (socket && socket.readyState === WebSocket.OPEN && isStreaming) {
296
- socket.send(JSON.stringify({
297
- action: 'mouse_input',
298
- requestId: generateRequestId(),
299
- x: x,
300
- y: y
301
- }));
302
- }
303
- }
304
-
305
- // Change scene
306
- function sendChangeScene(scene) {
307
- if (socket && socket.readyState === WebSocket.OPEN) {
308
- socket.send(JSON.stringify({
309
- action: 'change_scene',
310
- requestId: generateRequestId(),
311
- scene: scene
312
- }));
313
- }
314
- }
315
-
316
- // Process incoming frame
317
- function processFrame(message) {
318
- // Update FPS calculation
319
- const now = performance.now();
320
- if (lastFrameTime > 0) {
321
- frameCount++;
322
- }
323
- lastFrameTime = now;
324
-
325
- // Update the canvas with the new frame
326
- if (message.frameData) {
327
- gameCanvas.src = `data:image/jpeg;base64,${message.frameData}`;
328
- }
329
- }
330
-
331
- // Generate a random request ID
332
- function generateRequestId() {
333
- return Math.random().toString(36).substring(2, 15);
334
- }
335
-
336
- // Log message to the connection info panel
337
- function logMessage(message) {
338
- const logEntry = document.createElement('div');
339
- logEntry.className = 'log-entry';
340
-
341
- const timestamp = new Date().toLocaleTimeString();
342
- logEntry.textContent = `[${timestamp}] ${message}`;
343
-
344
- connectionLog.appendChild(logEntry);
345
- connectionLog.scrollTop = connectionLog.scrollHeight;
346
-
347
- // Limit number of log entries
348
- while (connectionLog.children.length > 100) {
349
- connectionLog.removeChild(connectionLog.firstChild);
350
- }
351
- }
352
-
353
- // Start FPS counter updates
354
- function startFpsCounter() {
355
- frameCount = 0;
356
- lastFrameTime = 0;
357
-
358
- // Update FPS display every second
359
- fpsUpdateInterval = setInterval(() => {
360
- fpsCounter.textContent = `FPS: ${frameCount}`;
361
- frameCount = 0;
362
- }, 1000);
363
- }
364
-
365
- // Stop FPS counter updates
366
- function stopFpsCounter() {
367
- if (fpsUpdateInterval) {
368
- clearInterval(fpsUpdateInterval);
369
- fpsUpdateInterval = null;
370
- }
371
- fpsCounter.textContent = 'FPS: 0';
372
- }
373
-
374
- // Reset UI to initial state
375
- function resetUI() {
376
- connectBtn.textContent = 'Connect';
377
- startStreamBtn.disabled = true;
378
- stopStreamBtn.disabled = true;
379
- sceneSelect.disabled = true;
380
-
381
- // Reset key indicators
382
- for (const key in keyElements) {
383
- keyElements[key].classList.remove('active');
384
- }
385
-
386
- // Stop FPS counter
387
- stopFpsCounter();
388
-
389
- // Reset streaming state
390
- isStreaming = false;
391
- }
392
-
393
- // Event Listeners
394
- connectBtn.addEventListener('click', () => {
395
- if (socket && socket.readyState === WebSocket.OPEN) {
396
- disconnectWebSocket();
397
- } else {
398
- connectWebSocket();
399
- }
400
- });
401
-
402
- startStreamBtn.addEventListener('click', sendStartStream);
403
- stopStreamBtn.addEventListener('click', sendStopStream);
404
-
405
- sceneSelect.addEventListener('change', () => {
406
- sendChangeScene(sceneSelect.value);
407
- });
408
-
409
- // Keyboard event listeners
410
- document.addEventListener('keydown', (event) => {
411
- const key = event.key.toLowerCase();
412
-
413
- // Map key to action
414
- let action = keyToAction[key];
415
- if (!action && key === ' ') {
416
- action = keyToAction[' ']; // Handle spacebar
417
- }
418
-
419
- if (action && !keyState[action]) {
420
- keyState[action] = true;
421
-
422
- // Update visual indicator
423
- const keyElement = keyElements[key] ||
424
- (key === ' ' ? keyElements['space'] : null) ||
425
- (key === 'shift' ? keyElements['shift'] : null);
426
-
427
- if (keyElement) {
428
- keyElement.classList.add('active');
429
- }
430
-
431
- // Send to server
432
- sendKeyboardInput(action, true);
433
- }
434
-
435
- // Prevent default actions for game controls
436
- if (Object.keys(keyToAction).includes(key) || key === ' ') {
437
- event.preventDefault();
438
- }
439
- });
440
-
441
- document.addEventListener('keyup', (event) => {
442
- const key = event.key.toLowerCase();
443
-
444
- // Map key to action
445
- let action = keyToAction[key];
446
- if (!action && key === ' ') {
447
- action = keyToAction[' ']; // Handle spacebar
448
- }
449
-
450
- if (action && keyState[action]) {
451
- keyState[action] = false;
452
-
453
- // Update visual indicator
454
- const keyElement = keyElements[key] ||
455
- (key === ' ' ? keyElements['space'] : null) ||
456
- (key === 'shift' ? keyElements['shift'] : null);
457
-
458
- if (keyElement) {
459
- keyElement.classList.remove('active');
460
- }
461
-
462
- // Send to server
463
- sendKeyboardInput(action, false);
464
- }
465
- });
466
-
467
- // Mouse capture functions
468
- function requestPointerLock() {
469
- if (!mouseState.captured && pointerLockSupported) {
470
- mouseTrackingArea.requestPointerLock = mouseTrackingArea.requestPointerLock ||
471
- mouseTrackingArea.mozRequestPointerLock ||
472
- mouseTrackingArea.webkitRequestPointerLock;
473
- mouseTrackingArea.requestPointerLock();
474
- logMessage('Mouse captured. Press ESC to release.');
475
- }
476
- }
477
-
478
- function exitPointerLock() {
479
- if (mouseState.captured) {
480
- document.exitPointerLock = document.exitPointerLock ||
481
- document.mozExitPointerLock ||
482
- document.webkitExitPointerLock;
483
- document.exitPointerLock();
484
- logMessage('Mouse released.');
485
- }
486
- }
487
-
488
- // Handle pointer lock change events
489
- document.addEventListener('pointerlockchange', pointerLockChangeHandler);
490
- document.addEventListener('mozpointerlockchange', pointerLockChangeHandler);
491
- document.addEventListener('webkitpointerlockchange', pointerLockChangeHandler);
492
-
493
- function pointerLockChangeHandler() {
494
- if (document.pointerLockElement === mouseTrackingArea ||
495
- document.mozPointerLockElement === mouseTrackingArea ||
496
- document.webkitPointerLockElement === mouseTrackingArea) {
497
- // Pointer is locked, enable mouse movement tracking
498
- mouseState.captured = true;
499
- document.addEventListener('mousemove', handleMouseMovement);
500
- } else {
501
- // Pointer is unlocked, disable mouse movement tracking
502
- mouseState.captured = false;
503
- document.removeEventListener('mousemove', handleMouseMovement);
504
- // Reset mouse state
505
- mouseState.x = 0;
506
- mouseState.y = 0;
507
- mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
508
- throttledSendMouseInput();
509
- }
510
- }
511
-
512
- // Mouse tracking with pointer lock
513
- function handleMouseMovement(event) {
514
- if (mouseState.captured) {
515
- // Use movement for mouse look when captured
516
- const sensitivity = 0.005; // Adjust sensitivity
517
- mouseState.x += event.movementX * sensitivity;
518
- mouseState.y -= event.movementY * sensitivity; // Invert Y for intuitive camera control
519
-
520
- // Clamp values
521
- mouseState.x = Math.max(-1, Math.min(1, mouseState.x));
522
- mouseState.y = Math.max(-1, Math.min(1, mouseState.y));
523
-
524
- // Update display
525
- mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
526
-
527
- // Send to server (throttled)
528
- throttledSendMouseInput();
529
- }
530
- }
531
-
532
- // Mouse click to capture
533
- mouseTrackingArea.addEventListener('click', () => {
534
- if (!mouseState.captured && isStreaming) {
535
- requestPointerLock();
536
- }
537
- });
538
-
539
- // Standard mouse tracking for when pointer is not locked
540
- mouseTrackingArea.addEventListener('mousemove', (event) => {
541
- if (!mouseState.captured) {
542
- // Calculate normalized coordinates relative to the center of the tracking area
543
- const rect = mouseTrackingArea.getBoundingClientRect();
544
- const centerX = rect.width / 2;
545
- const centerY = rect.height / 2;
546
-
547
- // Calculate relative position from center (-1 to 1)
548
- const relX = (event.clientX - rect.left - centerX) / centerX;
549
- const relY = (event.clientY - rect.top - centerY) / centerY;
550
-
551
- // Scale down for smoother movement (similar to conditions.py)
552
- const scaleFactor = 0.05;
553
- mouseState.x = relX * scaleFactor;
554
- mouseState.y = -relY * scaleFactor; // Invert Y for intuitive camera control
555
-
556
- // Update display
557
- mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
558
-
559
- // Send to server (throttled)
560
- throttledSendMouseInput();
561
- }
562
- });
563
-
564
- // Throttle mouse movement to avoid flooding the server
565
- const throttledSendMouseInput = (() => {
566
- let lastSentTime = 0;
567
- const interval = 50; // milliseconds
568
-
569
- return () => {
570
- const now = performance.now();
571
- if (now - lastSentTime >= interval) {
572
- sendMouseInput(mouseState.x, mouseState.y);
573
- lastSentTime = now;
574
- }
575
- };
576
- })();
577
-
578
- // Toggle panel collapse/expand
579
- function togglePanel(panelId) {
580
- const panel = document.getElementById(panelId);
581
- const button = panel.querySelector('.toggle-button');
582
-
583
- if (panel.classList.contains('collapsed')) {
584
- // Expand the panel
585
- panel.classList.remove('collapsed');
586
- button.textContent = '−'; // Minus sign
587
- } else {
588
- // Collapse the panel
589
- panel.classList.add('collapsed');
590
- button.textContent = '+'; // Plus sign
591
- }
592
- }
593
-
594
- // Initialize the UI
595
- resetUI();
596
-
597
- // Make panel headers clickable
598
- document.querySelectorAll('.panel-header').forEach(header => {
599
- header.addEventListener('click', () => {
600
- const panelId = header.parentElement.id;
601
- togglePanel(panelId);
602
- });
603
- });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
example/engine.py DELETED
@@ -1,332 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- """
5
- MatrixGame Engine
6
-
7
- This module handles the core rendering and model inference for the MatrixGame project.
8
- """
9
-
10
- import os
11
- import logging
12
- import argparse
13
- import time
14
- import torch
15
- import numpy as np
16
- from PIL import Image
17
- import cv2
18
- from einops import rearrange
19
- from diffusers.utils import load_image
20
- from diffusers.video_processor import VideoProcessor
21
- from typing import Dict, List, Tuple, Any, Optional, Union
22
-
23
- # MatrixGame specific imports
24
- from matrixgame.sample.pipeline_matrixgame import MatrixGameVideoPipeline
25
- from matrixgame.model_variants import get_dit
26
- from matrixgame.vae_variants import get_vae
27
- from matrixgame.encoder_variants import get_text_enc
28
- from matrixgame.model_variants.matrixgame_dit_src import MGVideoDiffusionTransformerI2V
29
- from matrixgame.sample.flow_matching_scheduler_matrixgame import FlowMatchDiscreteScheduler
30
- from teacache_forward import teacache_forward
31
-
32
- # Import utility functions
33
- from utils import (
34
- visualize_controls,
35
- frame_to_jpeg,
36
- load_scene_frames,
37
- logger
38
- )
39
-
40
- class MatrixGameEngine:
41
- """
42
- Core engine for MatrixGame model inference and frame generation.
43
- """
44
- def __init__(self, args: Optional[argparse.Namespace] = None):
45
- """
46
- Initialize the MatrixGame engine with configuration parameters.
47
-
48
- Args:
49
- args: Optional parsed command line arguments for model configuration
50
- """
51
- # Set default parameters if args not provided
52
- self.frame_width = getattr(args, 'frame_width', 640)
53
- self.frame_height = getattr(args, 'frame_height', 360)
54
- self.fps = getattr(args, 'fps', 16)
55
- self.inference_steps = getattr(args, 'inference_steps', 20)
56
- self.guidance_scale = getattr(args, 'guidance_scale', 6.0)
57
- self.num_pre_frames = getattr(args, 'num_pre_frames', 3)
58
-
59
- # Initialize state
60
- self.frame_count = 0
61
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
62
- self.weight_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
63
-
64
- # Model paths from environment or args
65
- self.vae_path = os.environ.get("VAE_PATH", "./models/matrixgame/vae/")
66
- self.dit_path = os.environ.get("DIT_PATH", "./models/matrixgame/dit/")
67
- self.textenc_path = os.environ.get("TEXTENC_PATH", "./models/matrixgame")
68
-
69
- # Cache scene initial frames
70
- self.scenes = {
71
- 'forest': load_scene_frames('forest', self.frame_width, self.frame_height),
72
- 'desert': load_scene_frames('desert', self.frame_width, self.frame_height),
73
- 'beach': load_scene_frames('beach', self.frame_width, self.frame_height),
74
- 'hills': load_scene_frames('hills', self.frame_width, self.frame_height),
75
- 'river': load_scene_frames('river', self.frame_width, self.frame_height),
76
- 'icy': load_scene_frames('icy', self.frame_width, self.frame_height),
77
- 'mushroom': load_scene_frames('mushroom', self.frame_width, self.frame_height),
78
- 'plain': load_scene_frames('plain', self.frame_width, self.frame_height)
79
- }
80
-
81
- # Cache initial images for model input
82
- self.scene_initial_images = {}
83
-
84
- # Initialize MatrixGame pipeline
85
- self.model_loaded = False
86
- if torch.cuda.is_available():
87
- try:
88
- self._init_models()
89
- self.model_loaded = True
90
- logger.info("MatrixGame models loaded successfully")
91
- except Exception as e:
92
- logger.error(f"Failed to initialize MatrixGame models: {str(e)}")
93
- logger.info("Falling back to frame cycling mode")
94
- else:
95
- logger.warning("CUDA not available. Using frame cycling mode only.")
96
-
97
- def _init_models(self):
98
- """Initialize MatrixGame models (VAE, text encoder, transformer)"""
99
- # Initialize flow matching scheduler
100
- self.scheduler = FlowMatchDiscreteScheduler(
101
- shift=15.0,
102
- reverse=True,
103
- solver="euler"
104
- )
105
-
106
- # Initialize VAE
107
- try:
108
- self.vae = get_vae("matrixgame", self.vae_path, self.weight_dtype)
109
- self.vae.requires_grad_(False)
110
- self.vae.eval()
111
- self.vae.enable_tiling()
112
- logger.info("VAE model loaded successfully")
113
- except Exception as e:
114
- logger.error(f"Error loading VAE model: {str(e)}")
115
- raise
116
-
117
- # Initialize DIT (Transformer)
118
- try:
119
- dit = MGVideoDiffusionTransformerI2V.from_pretrained(self.dit_path)
120
- dit.requires_grad_(False)
121
- dit.eval()
122
- logger.info("DIT model loaded successfully")
123
- except Exception as e:
124
- logger.error(f"Error loading DIT model: {str(e)}")
125
- raise
126
-
127
- # Initialize text encoder
128
- try:
129
- self.text_enc = get_text_enc('matrixgame', self.textenc_path, weight_dtype=self.weight_dtype, i2v_type='refiner')
130
- logger.info("Text encoder loaded successfully")
131
- except Exception as e:
132
- logger.error(f"Error loading text encoder: {str(e)}")
133
- raise
134
-
135
- # Initialize pipeline
136
- try:
137
- self.pipeline = MatrixGameVideoPipeline(
138
- vae=self.vae.vae,
139
- text_encoder=self.text_enc,
140
- transformer=dit,
141
- scheduler=self.scheduler,
142
- ).to(self.weight_dtype).to(self.device)
143
- logger.info("Pipeline initialized successfully")
144
- except Exception as e:
145
- logger.error(f"Error initializing pipeline: {str(e)}")
146
- raise
147
-
148
- # Configure teacache for the transformer
149
- self.pipeline.transformer.__class__.enable_teacache = True
150
- self.pipeline.transformer.__class__.cnt = 0
151
- self.pipeline.transformer.__class__.num_steps = self.inference_steps
152
- self.pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
153
- self.pipeline.transformer.__class__.rel_l1_thresh = 0.075
154
- self.pipeline.transformer.__class__.previous_modulated_input = None
155
- self.pipeline.transformer.__class__.previous_residual = None
156
- self.pipeline.transformer.__class__.forward = teacache_forward
157
-
158
- # Preprocess initial images for all scenes
159
- for scene_name, frames in self.scenes.items():
160
- if frames:
161
- # Use first frame as initial image
162
- self.scene_initial_images[scene_name] = self._preprocess_image(frames[0])
163
-
164
- def _preprocess_image(self, image_array: np.ndarray) -> torch.Tensor:
165
- """
166
- Preprocess an image for the model.
167
-
168
- Args:
169
- image_array: Input image as numpy array
170
-
171
- Returns:
172
- torch.Tensor: Preprocessed image tensor
173
- """
174
- # Convert numpy array to PIL Image if needed
175
- if isinstance(image_array, np.ndarray):
176
- image = Image.fromarray(image_array)
177
- else:
178
- image = image_array
179
-
180
- # Preprocess for VAE
181
- vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, 'vae') else 8
182
- video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor)
183
- initial_image = video_processor.preprocess(image, height=self.frame_height, width=self.frame_width)
184
-
185
- # Add past frames for stability (use same frame repeated)
186
- past_frames = initial_image.repeat(self.num_pre_frames, 1, 1, 1)
187
- initial_image = torch.cat([initial_image, past_frames], dim=0)
188
-
189
- return initial_image
190
-
191
- def generate_frame(self, scene_name: str, keyboard_condition: Optional[List] = None,
192
- mouse_condition: Optional[List] = None) -> bytes:
193
- """
194
- Generate the next frame based on current conditions using MatrixGame model.
195
-
196
- Args:
197
- scene_name: Name of the current scene
198
- keyboard_condition: Keyboard input state
199
- mouse_condition: Mouse input state
200
-
201
- Returns:
202
- bytes: JPEG bytes of the frame
203
- """
204
- # Check if model is loaded
205
- if not self.model_loaded or not torch.cuda.is_available():
206
- # Fall back to frame cycling for demo mode or if models failed to load
207
- return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
208
- else:
209
- # Use MatrixGame model for frame generation
210
- try:
211
- # Get initial image for this scene
212
- initial_image = self.scene_initial_images.get(scene_name)
213
- if initial_image is None:
214
- # Use forest as default if we don't have an initial image for this scene
215
- initial_image = self.scene_initial_images.get('forest')
216
- if initial_image is None:
217
- # If we still don't have an initial image, fall back to frame cycling
218
- logger.error(f"No initial image available for scene {scene_name}")
219
- return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
220
-
221
- # Prepare input tensors (move to device and format correctly)
222
- if keyboard_condition is None:
223
- keyboard_condition = [[0, 0, 0, 0, 0, 0]]
224
- if mouse_condition is None:
225
- mouse_condition = [[0, 0]]
226
-
227
- # Convert conditions to tensors
228
- keyboard_tensor = torch.tensor(keyboard_condition, dtype=torch.float32)
229
- mouse_tensor = torch.tensor(mouse_condition, dtype=torch.float32)
230
-
231
- # Move to device and convert to correct dtype
232
- keyboard_tensor = keyboard_tensor.to(self.weight_dtype).to(self.device)
233
- mouse_tensor = mouse_tensor.to(self.weight_dtype).to(self.device)
234
-
235
- # Get the first frame from the scene for semantic conditioning
236
- scene_frames = self.scenes.get(scene_name, self.scenes['forest'])
237
- if not scene_frames:
238
- return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
239
-
240
- semantic_image = Image.fromarray(scene_frames[0])
241
-
242
- # Get PIL image version of the frame for visualization
243
- for scene_frame in scene_frames:
244
- if isinstance(scene_frame, np.ndarray):
245
- semantic_image = Image.fromarray(scene_frame)
246
- break
247
-
248
- # Generate a single frame with the model
249
- # Use fewer inference steps for interactive frame generation
250
- with torch.no_grad():
251
- # Generate a short video (we'll just use the first frame)
252
- # We're using a short length (3 frames) for real-time performance
253
- video = self.pipeline(
254
- height=self.frame_height,
255
- width=self.frame_width,
256
- video_length=3, # Generate a very short video for speed
257
- mouse_condition=mouse_tensor,
258
- keyboard_condition=keyboard_tensor,
259
- initial_image=initial_image,
260
- num_inference_steps=self.inference_steps,
261
- guidance_scale=self.guidance_scale,
262
- embedded_guidance_scale=None,
263
- data_type="video",
264
- vae_ver='884-16c-hy',
265
- enable_tiling=True,
266
- generator=torch.Generator(device=self.device).manual_seed(42),
267
- i2v_type='refiner',
268
- semantic_images=semantic_image
269
- ).videos[0]
270
-
271
- # Convert video tensor to numpy array (use first frame)
272
- video_frame = video[0].permute(1, 2, 0).cpu().numpy()
273
- video_frame = (video_frame * 255).astype(np.uint8)
274
- frame = video_frame
275
-
276
- # Increment frame counter
277
- self.frame_count += 1
278
-
279
- except Exception as e:
280
- logger.error(f"Error generating frame with MatrixGame model: {str(e)}")
281
- # Fall back to cycling demo frames if model generation fails
282
- return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
283
-
284
- # Add visualization of input controls
285
- frame = visualize_controls(
286
- frame, keyboard_condition, mouse_condition,
287
- self.frame_width, self.frame_height
288
- )
289
-
290
- # Convert frame to JPEG
291
- return frame_to_jpeg(frame, self.frame_height, self.frame_width)
292
-
293
- def _fallback_frame(self, scene_name: str, keyboard_condition: Optional[List] = None,
294
- mouse_condition: Optional[List] = None) -> bytes:
295
- """
296
- Generate a fallback frame when model generation fails.
297
-
298
- Args:
299
- scene_name: Name of the current scene
300
- keyboard_condition: Keyboard input state
301
- mouse_condition: Mouse input state
302
-
303
- Returns:
304
- bytes: JPEG bytes of the frame
305
- """
306
- scene_frames = self.scenes.get(scene_name, self.scenes['forest'])
307
- frame_idx = self.frame_count % len(scene_frames)
308
- frame = scene_frames[frame_idx].copy()
309
- self.frame_count += 1
310
-
311
- # Add fallback mode indicator
312
- cv2.putText(frame, "Fallback mode",
313
- (10, self.frame_height - 20),
314
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
315
-
316
- # Add visualization of input controls
317
- frame = visualize_controls(
318
- frame, keyboard_condition, mouse_condition,
319
- self.frame_width, self.frame_height
320
- )
321
-
322
- # Convert frame to JPEG
323
- return frame_to_jpeg(frame, self.frame_height, self.frame_width)
324
-
325
- def get_valid_scenes(self) -> List[str]:
326
- """
327
- Get a list of valid scene names.
328
-
329
- Returns:
330
- List[str]: List of valid scene names
331
- """
332
- return list(self.scenes.keys())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
example/index.html DELETED
@@ -1,329 +0,0 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
- <head>
4
- <meta charset="UTF-8">
5
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
- <title>MatrixGame Client</title>
7
- <style>
8
- body {
9
- font-family: Arial, sans-serif;
10
- margin: 0;
11
- padding: 0;
12
- background-color: #121212;
13
- color: #e0e0e0;
14
- display: flex;
15
- flex-direction: column;
16
- align-items: center;
17
- user-select: none; /* Disable text selection */
18
- -webkit-user-select: none;
19
- -moz-user-select: none;
20
- -ms-user-select: none;
21
- overflow-x: hidden;
22
- }
23
-
24
- .container {
25
- width: 100%;
26
- max-width: 100%;
27
- display: flex;
28
- flex-direction: column;
29
- align-items: center;
30
- }
31
-
32
- .game-area {
33
- display: flex;
34
- flex-direction: column;
35
- align-items: center;
36
- width: 100%;
37
- max-height: 85vh;
38
- margin: 0;
39
- position: relative;
40
- }
41
-
42
- #mouse-tracking-area {
43
- position: relative;
44
- width: 100%;
45
- height: auto;
46
- cursor: pointer; /* Show cursor as pointer to encourage clicks */
47
- display: flex;
48
- justify-content: center;
49
- align-items: center;
50
- max-height: 85vh;
51
- }
52
-
53
- #game-canvas {
54
- width: 100%;
55
- height: auto;
56
- max-height: 85vh;
57
- object-fit: contain;
58
- background-color: #000;
59
- pointer-events: none; /* Prevent drag on the image */
60
- -webkit-user-drag: none;
61
- -khtml-user-drag: none;
62
- -moz-user-drag: none;
63
- -o-user-drag: none;
64
- user-drag: none;
65
- }
66
-
67
- .controls {
68
- display: flex;
69
- justify-content: space-between;
70
- width: 100%;
71
- max-width: 1200px;
72
- padding: 10px;
73
- background-color: rgba(0, 0, 0, 0.5);
74
- position: absolute;
75
- bottom: 0;
76
- z-index: 10;
77
- box-sizing: border-box;
78
- }
79
-
80
- .panels-container {
81
- display: flex;
82
- width: 100%;
83
- max-width: 1200px;
84
- margin: 10px auto;
85
- gap: 10px;
86
- }
87
-
88
- .panel {
89
- flex: 1;
90
- background-color: #1E1E1E;
91
- border-radius: 5px;
92
- overflow: hidden;
93
- box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
94
- transition: height 0.3s ease;
95
- }
96
-
97
- .panel-header {
98
- background-color: #272727;
99
- padding: 10px 15px;
100
- display: flex;
101
- justify-content: space-between;
102
- align-items: center;
103
- cursor: pointer;
104
- }
105
-
106
- .panel-title {
107
- font-weight: bold;
108
- color: #4CAF50;
109
- }
110
-
111
- .toggle-button {
112
- background: none;
113
- border: none;
114
- color: #e0e0e0;
115
- font-size: 18px;
116
- cursor: pointer;
117
- }
118
-
119
- .toggle-button:focus {
120
- outline: none;
121
- }
122
-
123
- .panel-content {
124
- padding: 15px;
125
- max-height: 300px;
126
- overflow-y: auto;
127
- transition: all 0.3s ease;
128
- }
129
-
130
- .collapsed .panel-content {
131
- max-height: 0;
132
- padding-top: 0;
133
- padding-bottom: 0;
134
- overflow: hidden;
135
- }
136
-
137
- button {
138
- background-color: #4CAF50;
139
- color: white;
140
- border: none;
141
- padding: 10px 15px;
142
- text-align: center;
143
- text-decoration: none;
144
- display: inline-block;
145
- font-size: 14px;
146
- border-radius: 5px;
147
- cursor: pointer;
148
- margin: 5px;
149
- transition: background-color 0.3s;
150
- }
151
-
152
- button:hover {
153
- background-color: #45a049;
154
- }
155
-
156
- button:disabled {
157
- background-color: #cccccc;
158
- cursor: not-allowed;
159
- }
160
-
161
- select {
162
- padding: 10px;
163
- border-radius: 5px;
164
- background-color: #2A2A2A;
165
- color: #e0e0e0;
166
- border: 1px solid #4CAF50;
167
- }
168
-
169
- .status {
170
- margin-top: 10px;
171
- color: #4CAF50;
172
- }
173
-
174
- .key-indicators {
175
- display: flex;
176
- justify-content: center;
177
- margin-top: 15px;
178
- }
179
-
180
- .key {
181
- width: 40px;
182
- height: 40px;
183
- margin: 0 5px;
184
- background-color: #2A2A2A;
185
- border: 1px solid #444;
186
- border-radius: 5px;
187
- display: flex;
188
- justify-content: center;
189
- align-items: center;
190
- font-weight: bold;
191
- transition: background-color 0.2s;
192
- }
193
-
194
- .key.active {
195
- background-color: #4CAF50;
196
- color: white;
197
- }
198
-
199
- .key-row {
200
- display: flex;
201
- justify-content: center;
202
- margin: 5px 0;
203
- }
204
-
205
- .spacebar {
206
- width: 150px;
207
- }
208
-
209
- .connection-info {
210
- font-family: monospace;
211
- height: 100%;
212
- overflow-y: auto;
213
- }
214
-
215
- .log-entry {
216
- margin: 5px 0;
217
- padding: 3px;
218
- border-bottom: 1px solid #333;
219
- }
220
-
221
- .fps-counter {
222
- position: absolute;
223
- top: 10px;
224
- right: 10px;
225
- background-color: rgba(0,0,0,0.5);
226
- color: #4CAF50;
227
- padding: 5px;
228
- border-radius: 3px;
229
- font-family: monospace;
230
- z-index: 20;
231
- }
232
-
233
-
234
- #mouse-position {
235
- position: absolute;
236
- top: 10px;
237
- left: 10px;
238
- background-color: rgba(0,0,0,0.5);
239
- color: #4CAF50;
240
- padding: 5px;
241
- border-radius: 3px;
242
- font-family: monospace;
243
- z-index: 20;
244
- }
245
-
246
- @media (max-width: 768px) {
247
- .panels-container {
248
- flex-direction: column;
249
- }
250
- }
251
- </style>
252
- </head>
253
- <body>
254
- <div class="container">
255
- <div class="game-area">
256
- <div id="mouse-tracking-area">
257
- <img id="game-canvas" src="" alt="Game Frame">
258
- <div id="mouse-position">Mouse: 0.00, 0.00</div>
259
- <div class="fps-counter" id="fps-counter">FPS: 0</div>
260
- </div>
261
-
262
- <div class="controls">
263
- <button id="connect-btn">Connect</button>
264
- <button id="start-stream-btn" disabled>Start Stream</button>
265
- <button id="stop-stream-btn" disabled>Stop Stream</button>
266
- <select id="scene-select" disabled>
267
- <option value="forest">Forest</option>
268
- <option value="desert">Desert</option>
269
- <option value="beach">Beach</option>
270
- <option value="hills">Hills</option>
271
- <option value="river">River</option>
272
- <option value="icy">Icy</option>
273
- <option value="mushroom">Mushroom</option>
274
- <option value="plain">Plain</option>
275
- </select>
276
- </div>
277
- </div>
278
-
279
- <div class="panels-container">
280
- <!-- Controls Panel -->
281
- <div class="panel" id="controls-panel">
282
- <div class="panel-header" onclick="togglePanel('controls-panel')">
283
- <div class="panel-title">Keyboard Controls</div>
284
- <button class="toggle-button">−</button>
285
- </div>
286
- <div class="panel-content">
287
- <div class="key-indicators">
288
- <div class="key-row">
289
- <div id="key-w" class="key">W</div>
290
- </div>
291
- <div class="key-row">
292
- <div id="key-a" class="key">A</div>
293
- <div id="key-s" class="key">S</div>
294
- <div id="key-d" class="key">D</div>
295
- </div>
296
- <div class="key-row">
297
- <div id="key-space" class="key spacebar">SPACE</div>
298
- </div>
299
- <div class="key-row">
300
- <div id="key-shift" class="key">SHIFT</div>
301
- </div>
302
- </div>
303
- <p class="status">
304
- W or ↑ = Forward, S or ↓ = Back, A or ← = Left, D or → = Right<br>
305
- Space = Jump, Shift = Attack<br>
306
- Click on game view to capture mouse (ESC to release)<br>
307
- Mouse = Look around
308
- </p>
309
- </div>
310
- </div>
311
-
312
- <!-- Connection Log Panel -->
313
- <div class="panel" id="log-panel">
314
- <div class="panel-header" onclick="togglePanel('log-panel')">
315
- <div class="panel-title">Connection Log</div>
316
- <button class="toggle-button">−</button>
317
- </div>
318
- <div class="panel-content">
319
- <div class="connection-info" id="connection-log">
320
- <div class="log-entry">Waiting to connect...</div>
321
- </div>
322
- </div>
323
- </div>
324
- </div>
325
- </div>
326
-
327
- <script src="./assets/client.js"></script>
328
- </body>
329
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
example/requirements.txt DELETED
@@ -1,23 +0,0 @@
1
- diffusers==0.32.2
2
- einops==0.8.1
3
-
4
- #flash_attn==2.7.4.post1
5
- flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
6
-
7
- ftfy==6.3.1
8
- imageio==2.34.0
9
- numpy==1.24.4
10
- opencv_python==4.9.0.80
11
- opencv_python_headless==4.9.0.80
12
- packaging==25.0
13
- peft==0.14.0
14
- Pillow==11.2.1
15
- regex==2024.11.6
16
- safetensors==0.5.3
17
- torch==2.5.1
18
- torchvision==0.20.1
19
- torchaudio==2.5.1
20
- transformers==4.47.1
21
- aiohttp==3.9.3
22
- jinja2==3.1.3
23
- python-multipart==0.0.6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
example/server.py DELETED
@@ -1,649 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- """
5
- MatrixGame Websocket Gaming Server
6
-
7
- This script implements a websocket server for the MatrixGame project,
8
- allowing real-time streaming of game frames based on player inputs.
9
- """
10
-
11
- import asyncio
12
- import json
13
- import logging
14
- import os
15
- import pathlib
16
- import time
17
- import uuid
18
- import base64
19
- import argparse
20
- from typing import Dict, List, Any, Optional
21
- from aiohttp import web, WSMsgType
22
-
23
- # Import the game engine
24
- from engine import MatrixGameEngine
25
- from utils import logger, parse_model_args, setup_gpu_environment
26
-
27
- class GameSession:
28
- """
29
- Represents a user's gaming session.
30
- Each WebSocket connection gets its own session with separate queues.
31
- """
32
- def __init__(self, user_id: str, ws: web.WebSocketResponse, game_manager):
33
- self.user_id = user_id
34
- self.ws = ws
35
- self.game_manager = game_manager
36
-
37
- # Create action queue for this user session
38
- self.action_queue = asyncio.Queue()
39
-
40
- # Session creation time
41
- self.created_at = time.time()
42
- self.last_activity = time.time()
43
-
44
- # Game state
45
- self.current_scene = "forest" # Default scene
46
- self.is_streaming = False
47
- self.stream_task = None
48
-
49
- # Current input state
50
- self.keyboard_state = [0, 0, 0, 0, 0, 0] # forward, back, left, right, jump, attack
51
- self.mouse_state = [0, 0] # x, y
52
-
53
- self.background_tasks = []
54
-
55
- async def start(self):
56
- """Start all the queue processors for this session"""
57
- self.background_tasks = [
58
- asyncio.create_task(self._process_action_queue()),
59
- ]
60
- logger.info(f"Started game session for user {self.user_id}")
61
-
62
- async def stop(self):
63
- """Stop all background tasks for this session"""
64
- # Stop streaming if active
65
- if self.is_streaming and self.stream_task:
66
- self.is_streaming = False
67
- self.stream_task.cancel()
68
- try:
69
- await self.stream_task
70
- except asyncio.CancelledError:
71
- pass
72
-
73
- # Cancel other background tasks
74
- for task in self.background_tasks:
75
- task.cancel()
76
-
77
- try:
78
- # Wait for tasks to complete cancellation
79
- await asyncio.gather(*self.background_tasks, return_exceptions=True)
80
- except asyncio.CancelledError:
81
- pass
82
-
83
- logger.info(f"Stopped game session for user {self.user_id}")
84
-
85
- async def _process_action_queue(self):
86
- """Process game actions from the queue"""
87
- while True:
88
- data = await self.action_queue.get()
89
- try:
90
- action_type = data.get('action')
91
-
92
- if action_type == 'start_stream':
93
- result = await self._handle_start_stream(data)
94
- elif action_type == 'stop_stream':
95
- result = await self._handle_stop_stream(data)
96
- elif action_type == 'keyboard_input':
97
- result = await self._handle_keyboard_input(data)
98
- elif action_type == 'mouse_input':
99
- result = await self._handle_mouse_input(data)
100
- elif action_type == 'change_scene':
101
- result = await self._handle_scene_change(data)
102
- else:
103
- result = {
104
- 'action': action_type,
105
- 'requestId': data.get('requestId'),
106
- 'success': False,
107
- 'error': f'Unknown action: {action_type}'
108
- }
109
-
110
- # Send response back to the client
111
- await self.ws.send_json(result)
112
-
113
- # Update last activity time
114
- self.last_activity = time.time()
115
-
116
- except Exception as e:
117
- logger.error(f"Error processing action for user {self.user_id}: {str(e)}")
118
- try:
119
- await self.ws.send_json({
120
- 'action': data.get('action'),
121
- 'requestId': data.get('requestId', 'unknown'),
122
- 'success': False,
123
- 'error': f'Error processing action: {str(e)}'
124
- })
125
- except Exception as send_error:
126
- logger.error(f"Error sending error response: {send_error}")
127
- finally:
128
- self.action_queue.task_done()
129
-
130
- async def _handle_start_stream(self, data: Dict) -> Dict:
131
- """Handle request to start streaming frames"""
132
- if self.is_streaming:
133
- return {
134
- 'action': 'start_stream',
135
- 'requestId': data.get('requestId'),
136
- 'success': False,
137
- 'error': 'Stream already active'
138
- }
139
-
140
- fps = data.get('fps', 16)
141
- self.is_streaming = True
142
- self.stream_task = asyncio.create_task(self._stream_frames(fps))
143
-
144
- return {
145
- 'action': 'start_stream',
146
- 'requestId': data.get('requestId'),
147
- 'success': True,
148
- 'message': f'Streaming started at {fps} FPS'
149
- }
150
-
151
- async def _handle_stop_stream(self, data: Dict) -> Dict:
152
- """Handle request to stop streaming frames"""
153
- if not self.is_streaming:
154
- return {
155
- 'action': 'stop_stream',
156
- 'requestId': data.get('requestId'),
157
- 'success': False,
158
- 'error': 'No active stream to stop'
159
- }
160
-
161
- self.is_streaming = False
162
- if self.stream_task:
163
- self.stream_task.cancel()
164
- try:
165
- await self.stream_task
166
- except asyncio.CancelledError:
167
- pass
168
- self.stream_task = None
169
-
170
- return {
171
- 'action': 'stop_stream',
172
- 'requestId': data.get('requestId'),
173
- 'success': True,
174
- 'message': 'Streaming stopped'
175
- }
176
-
177
- async def _handle_keyboard_input(self, data: Dict) -> Dict:
178
- """Handle keyboard input from client"""
179
- key = data.get('key', '')
180
- pressed = data.get('pressed', False)
181
-
182
- # Map key to keyboard state index
183
- key_map = {
184
- 'w': 0, 'forward': 0,
185
- 's': 1, 'back': 1, 'backward': 1,
186
- 'a': 2, 'left': 2,
187
- 'd': 3, 'right': 3,
188
- 'space': 4, 'jump': 4,
189
- 'shift': 5, 'attack': 5, 'ctrl': 5
190
- }
191
-
192
- if key.lower() in key_map:
193
- key_idx = key_map[key.lower()]
194
- self.keyboard_state[key_idx] = 1 if pressed else 0
195
-
196
- return {
197
- 'action': 'keyboard_input',
198
- 'requestId': data.get('requestId'),
199
- 'success': True,
200
- 'keyboardState': self.keyboard_state
201
- }
202
-
203
- async def _handle_mouse_input(self, data: Dict) -> Dict:
204
- """Handle mouse movement/input from client"""
205
- mouse_x = data.get('x', 0)
206
- mouse_y = data.get('y', 0)
207
-
208
- # Update mouse state, normalize values between -1 and 1
209
- self.mouse_state = [float(mouse_x), float(mouse_y)]
210
-
211
- return {
212
- 'action': 'mouse_input',
213
- 'requestId': data.get('requestId'),
214
- 'success': True,
215
- 'mouseState': self.mouse_state
216
- }
217
-
218
- async def _handle_scene_change(self, data: Dict) -> Dict:
219
- """Handle scene change requests"""
220
- scene_name = data.get('scene', 'forest')
221
- valid_scenes = self.game_manager.valid_scenes
222
-
223
- if scene_name not in valid_scenes:
224
- return {
225
- 'action': 'change_scene',
226
- 'requestId': data.get('requestId'),
227
- 'success': False,
228
- 'error': f'Invalid scene: {scene_name}. Valid scenes are: {", ".join(valid_scenes)}'
229
- }
230
-
231
- self.current_scene = scene_name
232
-
233
- return {
234
- 'action': 'change_scene',
235
- 'requestId': data.get('requestId'),
236
- 'success': True,
237
- 'scene': scene_name
238
- }
239
-
240
- async def _stream_frames(self, fps: int):
241
- """Stream frames to the client at the specified FPS"""
242
- frame_interval = 1.0 / fps # Time between frames in seconds
243
-
244
- try:
245
- while self.is_streaming:
246
- start_time = time.time()
247
-
248
- # Generate frame based on current keyboard and mouse state
249
- keyboard_condition = [self.keyboard_state]
250
- mouse_condition = [self.mouse_state]
251
-
252
- # Use the engine to generate the next frame
253
- frame_bytes = self.game_manager.engine.generate_frame(
254
- self.current_scene, keyboard_condition, mouse_condition
255
- )
256
-
257
- # Encode as base64 for sending in JSON
258
- frame_base64 = base64.b64encode(frame_bytes).decode('utf-8')
259
-
260
- # Send frame to client
261
- await self.ws.send_json({
262
- 'action': 'frame',
263
- 'frameData': frame_base64,
264
- 'timestamp': time.time()
265
- })
266
-
267
- # Calculate sleep time to maintain FPS
268
- elapsed = time.time() - start_time
269
- sleep_time = max(0, frame_interval - elapsed)
270
- await asyncio.sleep(sleep_time)
271
-
272
- except asyncio.CancelledError:
273
- logger.info(f"Frame streaming cancelled for user {self.user_id}")
274
- except Exception as e:
275
- logger.error(f"Error in frame streaming for user {self.user_id}: {str(e)}")
276
- if self.ws.closed:
277
- logger.info(f"WebSocket closed for user {self.user_id}")
278
- return
279
-
280
- # Notify client of error
281
- try:
282
- await self.ws.send_json({
283
- 'action': 'frame_error',
284
- 'error': f'Streaming error: {str(e)}'
285
- })
286
- except:
287
- pass
288
-
289
- # Stop streaming
290
- self.is_streaming = False
291
-
292
- class GameManager:
293
- """
294
- Manages all active gaming sessions and shared resources.
295
- """
296
- def __init__(self, args: argparse.Namespace):
297
- self.sessions = {}
298
- self.session_lock = asyncio.Lock()
299
-
300
- # Initialize game engine
301
- self.engine = MatrixGameEngine(args)
302
-
303
- # Load valid scenes from engine
304
- self.valid_scenes = self.engine.get_valid_scenes()
305
-
306
- async def create_session(self, user_id: str, ws: web.WebSocketResponse) -> GameSession:
307
- """Create a new game session"""
308
- async with self.session_lock:
309
- # Create a new session for this user
310
- session = GameSession(user_id, ws, self)
311
- await session.start()
312
- self.sessions[user_id] = session
313
- return session
314
-
315
- async def delete_session(self, user_id: str) -> None:
316
- """Delete a game session and clean up resources"""
317
- async with self.session_lock:
318
- if user_id in self.sessions:
319
- session = self.sessions[user_id]
320
- await session.stop()
321
- del self.sessions[user_id]
322
- logger.info(f"Deleted game session for user {user_id}")
323
-
324
- def get_session(self, user_id: str) -> Optional[GameSession]:
325
- """Get a game session if it exists"""
326
- return self.sessions.get(user_id)
327
-
328
- async def close_all_sessions(self) -> None:
329
- """Close all active sessions (used during shutdown)"""
330
- async with self.session_lock:
331
- for user_id, session in list(self.sessions.items()):
332
- await session.stop()
333
- self.sessions.clear()
334
- logger.info("Closed all active game sessions")
335
-
336
- @property
337
- def session_count(self) -> int:
338
- """Get the number of active sessions"""
339
- return len(self.sessions)
340
-
341
- def get_session_stats(self) -> Dict:
342
- """Get statistics about active sessions"""
343
- stats = {
344
- 'total_sessions': len(self.sessions),
345
- 'active_scenes': {},
346
- 'streaming_sessions': 0
347
- }
348
-
349
- # Count sessions by scene and streaming status
350
- for session in self.sessions.values():
351
- scene = session.current_scene
352
- stats['active_scenes'][scene] = stats['active_scenes'].get(scene, 0) + 1
353
- if session.is_streaming:
354
- stats['streaming_sessions'] += 1
355
-
356
- return stats
357
-
358
- # Create global game manager
359
- game_manager = None
360
-
361
- async def status_handler(request: web.Request) -> web.Response:
362
- """Handler for API status endpoint"""
363
- # Get session statistics
364
- session_stats = game_manager.get_session_stats()
365
-
366
- return web.json_response({
367
- 'product': 'MatrixGame WebSocket Server',
368
- 'version': '1.0.0',
369
- 'active_sessions': session_stats,
370
- 'available_scenes': game_manager.valid_scenes
371
- })
372
-
373
- async def root_handler(request: web.Request) -> web.Response:
374
- """Handler for serving the client at the root path"""
375
- client_path = pathlib.Path(__file__).parent / 'client' / 'index.html'
376
-
377
- with open(client_path, 'r') as file:
378
- html_content = file.read()
379
-
380
- return web.Response(text=html_content, content_type='text/html')
381
-
382
- async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
383
- """Handle WebSocket connections with robust error handling"""
384
- logger.info(f"WebSocket connection attempt - PATH: {request.path}, QUERY: {request.query_string}")
385
-
386
- # Log request headers at debug level only (could contain sensitive information)
387
- logger.debug(f"WebSocket request headers: {dict(request.headers)}")
388
-
389
- # Prepare a WebSocket response with appropriate settings
390
- ws = web.WebSocketResponse(
391
- max_msg_size=1024*1024*10, # 10MB max message size
392
- timeout=60.0,
393
- heartbeat=30.0 # Add heartbeat to keep connection alive
394
- )
395
-
396
- # Check if WebSocket protocol is supported
397
- if not ws.can_prepare(request):
398
- logger.error("Cannot prepare WebSocket: WebSocket protocol not supported")
399
- return web.Response(status=400, text="WebSocket protocol not supported")
400
-
401
- try:
402
- logger.info("Preparing WebSocket connection...")
403
- await ws.prepare(request)
404
-
405
- # Generate a unique user ID for this connection
406
- user_id = str(uuid.uuid4())
407
-
408
- # Get client IP address
409
- peername = request.transport.get_extra_info('peername')
410
- if peername is not None:
411
- client_ip = peername[0]
412
- else:
413
- client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()
414
-
415
- # Log connection success
416
- logger.info(f"Client {user_id} connecting from IP: {client_ip} - WebSocket connection established")
417
-
418
- # Mark that the session is established
419
- is_session_created = False
420
-
421
- try:
422
- # Store the user ID in the websocket for easy access
423
- ws.user_id = user_id
424
-
425
- # Create a new session for this user
426
- logger.info(f"Creating game session for user {user_id}")
427
- user_session = await game_manager.create_session(user_id, ws)
428
- is_session_created = True
429
- logger.info(f"Game session created for user {user_id}")
430
- except Exception as session_error:
431
- logger.error(f"Error creating game session: {str(session_error)}", exc_info=True)
432
- if not ws.closed:
433
- await ws.close(code=1011, message=f"Server error: {str(session_error)}".encode())
434
- if is_session_created:
435
- await game_manager.delete_session(user_id)
436
- return ws
437
- except Exception as e:
438
- logger.error(f"Error establishing WebSocket connection: {str(e)}", exc_info=True)
439
- if not ws.closed and ws.prepared:
440
- await ws.close(code=1011, message=f"Server error: {str(e)}".encode())
441
- return ws
442
-
443
- # Send initial welcome message
444
- try:
445
- await ws.send_json({
446
- 'action': 'welcome',
447
- 'userId': user_id,
448
- 'message': 'Welcome to the MatrixGame WebSocket server!',
449
- 'scenes': game_manager.valid_scenes
450
- })
451
- logger.info(f"Sent welcome message to user {user_id}")
452
- except Exception as welcome_error:
453
- logger.error(f"Error sending welcome message: {str(welcome_error)}")
454
- if not ws.closed:
455
- await ws.close(code=1011, message=b"Failed to send welcome message")
456
- await game_manager.delete_session(user_id)
457
- return ws
458
-
459
- try:
460
- async for msg in ws:
461
- if msg.type == WSMsgType.TEXT:
462
- try:
463
- data = json.loads(msg.data)
464
- action = data.get('action')
465
-
466
- logger.debug(f"Received {action} message from user {user_id}")
467
-
468
- if action == 'ping':
469
- # Respond to ping immediately
470
- await ws.send_json({
471
- 'action': 'pong',
472
- 'requestId': data.get('requestId'),
473
- 'timestamp': time.time()
474
- })
475
- else:
476
- # Route game actions to the session's action queue
477
- await user_session.action_queue.put(data)
478
-
479
- except json.JSONDecodeError:
480
- logger.error(f"Invalid JSON from user {user_id}: {msg.data}")
481
- if not ws.closed:
482
- await ws.send_json({
483
- 'error': 'Invalid JSON message',
484
- 'success': False
485
- })
486
- except Exception as e:
487
- logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
488
- if not ws.closed:
489
- await ws.send_json({
490
- 'action': data.get('action') if 'data' in locals() else 'unknown',
491
- 'success': False,
492
- 'error': f'Error processing message: {str(e)}'
493
- })
494
-
495
- elif msg.type == WSMsgType.ERROR:
496
- logger.error(f"WebSocket error for user {user_id}: {ws.exception()}")
497
- break
498
-
499
- elif msg.type == WSMsgType.CLOSE:
500
- logger.info(f"WebSocket close received for user {user_id} (code: {msg.data}, message: {msg.extra})")
501
- break
502
-
503
- elif msg.type == WSMsgType.CLOSING:
504
- logger.info(f"WebSocket closing for user {user_id}")
505
- break
506
-
507
- elif msg.type == WSMsgType.CLOSED:
508
- logger.info(f"WebSocket already closed for user {user_id}")
509
- break
510
-
511
- except Exception as ws_error:
512
- logger.error(f"Unexpected WebSocket error for user {user_id}: {str(ws_error)}", exc_info=True)
513
- finally:
514
- # Cleanup session
515
- try:
516
- logger.info(f"Cleaning up session for user {user_id}")
517
- await game_manager.delete_session(user_id)
518
- logger.info(f"Connection closed for user {user_id}")
519
- except Exception as cleanup_error:
520
- logger.error(f"Error during session cleanup for user {user_id}: {str(cleanup_error)}")
521
-
522
- return ws
523
-
524
- async def init_app(args, base_path="") -> web.Application:
525
- """Initialize the web application"""
526
- global game_manager
527
-
528
- # Initialize game manager with command line args
529
- game_manager = GameManager(args)
530
-
531
- app = web.Application(
532
- client_max_size=1024**2*10 # 10MB max size
533
- )
534
-
535
- # Add cleanup logic
536
- async def cleanup(app):
537
- logger.info("Shutting down server, closing all sessions...")
538
- await game_manager.close_all_sessions()
539
-
540
- app.on_shutdown.append(cleanup)
541
-
542
- # Add routes with CORS headers for WebSockets
543
- # Configure CORS for all routes
544
- @web.middleware
545
- async def cors_middleware(request, handler):
546
- if request.method == 'OPTIONS':
547
- # Handle preflight requests
548
- resp = web.Response()
549
- resp.headers['Access-Control-Allow-Origin'] = '*'
550
- resp.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
551
- resp.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
552
- return resp
553
-
554
- # Normal request, call the handler
555
- resp = await handler(request)
556
-
557
- # Add CORS headers to the response
558
- resp.headers['Access-Control-Allow-Origin'] = '*'
559
- resp.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
560
- resp.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
561
- return resp
562
-
563
- app.middlewares.append(cors_middleware)
564
-
565
- # Add a debug endpoint to help diagnose WebSocket issues
566
- async def debug_handler(request):
567
- client_ip = request.remote
568
- headers = dict(request.headers)
569
- server_host = request.host
570
-
571
- debug_info = {
572
- "client_ip": client_ip,
573
- "server_host": server_host,
574
- "headers": headers,
575
- "request_path": request.path,
576
- "server_time": time.time(),
577
- "base_path": base_path,
578
- "websocket_route": f"{base_path}/ws",
579
- "all_routes": [route.name for route in app.router.routes() if route.name],
580
- "server_info": {
581
- "active_sessions": game_manager.session_count,
582
- "available_scenes": game_manager.valid_scenes
583
- }
584
- }
585
-
586
- return web.json_response(debug_info)
587
-
588
- # Set up routes with the base_path
589
- # Add multiple WebSocket routes to ensure compatibility
590
- logger.info(f"Setting up WebSocket route at {base_path}/ws")
591
- app.router.add_get(f'{base_path}/ws', websocket_handler, name='ws_handler')
592
-
593
- # Also add WebSocket route at the root for Hugging Face compatibility
594
- if base_path:
595
- logger.info(f"Adding additional WebSocket route at /ws")
596
- app.router.add_get('/ws', websocket_handler, name='ws_root_handler')
597
-
598
- # Add routes for API and debug endpoints
599
- app.router.add_get(f'{base_path}/api/status', status_handler, name='status_handler')
600
- app.router.add_get(f'{base_path}/api/debug', debug_handler, name='debug_handler')
601
-
602
- # Serve the client at both the base path and root path for compatibility
603
- app.router.add_get(f'{base_path}/', root_handler, name='root_handler')
604
-
605
- # Always serve at the root path for Hugging Face Spaces compatibility
606
- if base_path:
607
- app.router.add_get('/', root_handler, name='root_handler_no_base')
608
-
609
- # Set up static file serving for the client assets
610
- app.router.add_static(f'{base_path}/assets', pathlib.Path(__file__).parent / 'client', name='static_handler')
611
-
612
- # Add static file serving at root for compatibility
613
- if base_path:
614
- app.router.add_static('/assets', pathlib.Path(__file__).parent / 'client', name='static_handler_no_base')
615
-
616
- return app
617
-
618
- def parse_args() -> argparse.Namespace:
619
- """Parse server-specific command line arguments"""
620
- parser = argparse.ArgumentParser(description="MatrixGame WebSocket Server")
621
- parser.add_argument("--host", type=str, default="0.0.0.0", help="Host IP to bind to")
622
- parser.add_argument("--port", type=int, default=8080, help="Port to listen on")
623
- parser.add_argument("--path", type=str, default="", help="Base path for the server (for proxy setups)")
624
-
625
- # Parse server args first
626
- server_args, remaining_args = parser.parse_known_args()
627
-
628
- # Parse model args and combine
629
- model_args = parse_model_args()
630
-
631
- # Combine all args
632
- combined_args = argparse.Namespace(**vars(server_args), **vars(model_args))
633
-
634
- return combined_args
635
-
636
- if __name__ == '__main__':
637
- # Configure GPU environment
638
- setup_gpu_environment()
639
-
640
- # Parse command line arguments
641
- args = parse_args()
642
-
643
- # Initialize app
644
- loop = asyncio.get_event_loop()
645
- app = loop.run_until_complete(init_app(args, base_path=args.path))
646
-
647
- # Start server
648
- logger.info(f"Starting MatrixGame WebSocket Server at {args.host}:{args.port}")
649
- web.run_app(app, host=args.host, port=args.port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
example/utils.py DELETED
@@ -1,202 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- """
5
- MatrixGame Utility Functions
6
-
7
- This module contains helper functions and utilities for the MatrixGame project.
8
- """
9
-
10
- import os
11
- import logging
12
- import argparse
13
- import torch
14
- import numpy as np
15
- import cv2
16
- from PIL import Image
17
- from typing import Dict, List, Tuple, Any, Optional, Union
18
-
19
- # Configure logging
20
- logging.basicConfig(
21
- level=logging.INFO,
22
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
23
- )
24
- logger = logging.getLogger(__name__)
25
-
26
- def setup_gpu_environment():
27
- """
28
- Configure the GPU environment and log GPU information.
29
-
30
- Returns:
31
- bool: True if CUDA is available, False otherwise
32
- """
33
- # Set CUDA memory allocation environment variable for better performance
34
- os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
35
-
36
- # Check if CUDA is available and log information
37
- if torch.cuda.is_available():
38
- gpu_count = torch.cuda.device_count()
39
- gpu_info = []
40
-
41
- for i in range(gpu_count):
42
- gpu_name = torch.cuda.get_device_name(i)
43
- gpu_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3) # Convert to GB
44
- gpu_info.append(f"GPU {i}: {gpu_name} ({gpu_memory:.2f} GB)")
45
-
46
- logger.info(f"CUDA is available. Found {gpu_count} GPU(s):")
47
- for info in gpu_info:
48
- logger.info(f" {info}")
49
- return True
50
- else:
51
- logger.warning("CUDA is not available. Running in CPU-only mode.")
52
- return False
53
-
54
- def parse_model_args() -> argparse.Namespace:
55
- """
56
- Parse command line arguments for model paths and configuration.
57
-
58
- Returns:
59
- argparse.Namespace: Parsed arguments
60
- """
61
- parser = argparse.ArgumentParser(description="MatrixGame Model Configuration")
62
-
63
- # Model paths
64
- parser.add_argument("--model_root", type=str, default="./models/matrixgame",
65
- help="Root directory for model files")
66
- parser.add_argument("--dit_path", type=str, default=None,
67
- help="Path to DIT model. If not provided, will use MODEL_ROOT/dit/")
68
- parser.add_argument("--vae_path", type=str, default=None,
69
- help="Path to VAE model. If not provided, will use MODEL_ROOT/vae/")
70
- parser.add_argument("--textenc_path", type=str, default=None,
71
- help="Path to text encoder model. If not provided, will use MODEL_ROOT")
72
-
73
- # Model settings
74
- parser.add_argument("--inference_steps", type=int, default=20,
75
- help="Number of inference steps for frame generation (lower is faster)")
76
- parser.add_argument("--guidance_scale", type=float, default=6.0,
77
- help="Guidance scale for generation")
78
- parser.add_argument("--frame_width", type=int, default=640,
79
- help="Width of the generated frames")
80
- parser.add_argument("--frame_height", type=int, default=360,
81
- help="Height of the generated frames")
82
- parser.add_argument("--num_pre_frames", type=int, default=3,
83
- help="Number of pre-frames for conditioning")
84
- parser.add_argument("--fps", type=int, default=16,
85
- help="Frames per second for video")
86
-
87
- args = parser.parse_args()
88
-
89
- # Set environment variables for model paths if provided
90
- if args.model_root:
91
- os.environ.setdefault("MODEL_ROOT", args.model_root)
92
- if args.dit_path:
93
- os.environ.setdefault("DIT_PATH", args.dit_path)
94
- else:
95
- os.environ.setdefault("DIT_PATH", os.path.join(os.environ.get("MODEL_ROOT", "./models/matrixgame"), "dit/"))
96
- if args.vae_path:
97
- os.environ.setdefault("VAE_PATH", args.vae_path)
98
- else:
99
- os.environ.setdefault("VAE_PATH", os.path.join(os.environ.get("MODEL_ROOT", "./models/matrixgame"), "vae/"))
100
- if args.textenc_path:
101
- os.environ.setdefault("TEXTENC_PATH", args.textenc_path)
102
- else:
103
- os.environ.setdefault("TEXTENC_PATH", os.environ.get("MODEL_ROOT", "./models/matrixgame"))
104
-
105
- return args
106
-
107
- def visualize_controls(frame: np.ndarray, keyboard_condition: List, mouse_condition: List,
108
- frame_width: int, frame_height: int) -> np.ndarray:
109
- """
110
- Visualize keyboard and mouse controls on the frame.
111
-
112
- Args:
113
- frame: The video frame to visualize on
114
- keyboard_condition: Keyboard state as a list
115
- mouse_condition: Mouse state as a list
116
- frame_width: Width of the frame
117
- frame_height: Height of the frame
118
-
119
- Returns:
120
- np.ndarray: Frame with visualized controls
121
- """
122
- # Clone the frame to avoid modifying the original
123
- frame = frame.copy()
124
-
125
- # If we have keyboard/mouse conditions, visualize them on the frame
126
- if keyboard_condition:
127
- # Visualize keyboard inputs
128
- keys = ["W", "S", "A", "D", "JUMP", "ATTACK"]
129
- for i, key_pressed in enumerate(keyboard_condition[0]):
130
- color = (0, 255, 0) if key_pressed else (100, 100, 100)
131
- cv2.putText(frame, keys[i], (20 + i*100, 30),
132
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
133
-
134
- if mouse_condition:
135
- # Visualize mouse movement
136
- mouse_x, mouse_y = mouse_condition[0]
137
- # Scale mouse values for visualization
138
- offset_x = int(mouse_x * 100)
139
- offset_y = int(mouse_y * 100)
140
- center_x, center_y = frame_width // 2, frame_height // 2
141
- cv2.circle(frame, (center_x + offset_x, center_y - offset_y), 10, (255, 0, 0), -1)
142
- cv2.putText(frame, f"Mouse: {mouse_x:.2f}, {mouse_y:.2f}",
143
- (frame_width - 250, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
144
-
145
- return frame
146
-
147
- def frame_to_jpeg(frame: np.ndarray, frame_height: int, frame_width: int) -> bytes:
148
- """
149
- Convert a frame to JPEG bytes.
150
-
151
- Args:
152
- frame: The video frame to convert
153
- frame_height: Height of the frame for fallback
154
- frame_width: Width of the frame for fallback
155
-
156
- Returns:
157
- bytes: JPEG bytes of the frame
158
- """
159
- success, buffer = cv2.imencode('.jpg', frame)
160
- if not success:
161
- logger.error("Failed to encode frame as JPEG")
162
- # Return a blank frame
163
- blank = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 100
164
- success, buffer = cv2.imencode('.jpg', blank)
165
-
166
- return buffer.tobytes()
167
-
168
- def load_scene_frames(scene_name: str, frame_width: int, frame_height: int) -> List[np.ndarray]:
169
- """
170
- Load initial frames for a scene from asset directory.
171
-
172
- Args:
173
- scene_name: Name of the scene
174
- frame_width: Width to resize frames to
175
- frame_height: Height to resize frames to
176
-
177
- Returns:
178
- List[np.ndarray]: List of frames as numpy arrays
179
- """
180
- frames = []
181
- scene_dir = f"./GameWorldScore/asset/init_image/{scene_name}"
182
-
183
- if os.path.exists(scene_dir):
184
- image_files = sorted([f for f in os.listdir(scene_dir) if f.endswith('.png') or f.endswith('.jpg')])
185
- for img_file in image_files:
186
- try:
187
- img_path = os.path.join(scene_dir, img_file)
188
- img = Image.open(img_path).convert("RGB")
189
- img = img.resize((frame_width, frame_height))
190
- frames.append(np.array(img))
191
- except Exception as e:
192
- logger.error(f"Error loading image {img_file}: {str(e)}")
193
-
194
- # If no frames were loaded, create a default colored frame with text
195
- if not frames:
196
- frame = np.ones((frame_height, frame_height, 3), dtype=np.uint8) * 100
197
- # Add scene name as text
198
- cv2.putText(frame, f"Scene: {scene_name}", (50, 180),
199
- cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
200
- frames.append(frame)
201
-
202
- return frames
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
reference_example/Dockerfile DELETED
@@ -1,52 +0,0 @@
1
- FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
2
-
3
- ARG DEBIAN_FRONTEND=noninteractive
4
-
5
- ENV PYTHONUNBUFFERED=1
6
-
7
- RUN apt-get update && apt-get install --no-install-recommends -y \
8
- build-essential \
9
- python3.11 \
10
- python3-pip \
11
- python3-dev \
12
- git \
13
- curl \
14
- ffmpeg \
15
- libglib2.0-0 \
16
- libsm6 \
17
- libxrender1 \
18
- libxext6 \
19
- && apt-get clean && rm -rf /var/lib/apt/lists/*
20
-
21
- WORKDIR /code
22
-
23
- COPY ./requirements.txt /code/requirements.txt
24
-
25
- # Set up a new user named "user" with user ID 1000
26
- RUN useradd -m -u 1000 user
27
- # Switch to the "user" user
28
- USER user
29
- # Set home to the user's home directory
30
- ENV HOME=/home/user \
31
- PATH=/home/user/.local/bin:$PATH
32
-
33
- # Set home to the user's home directory
34
- ENV PYTHONPATH=$HOME/app \
35
- PYTHONUNBUFFERED=1 \
36
- DATA_ROOT=/tmp/data
37
-
38
- RUN echo "Installing requirements.txt"
39
- RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
40
-
41
- # yeah.. this is manual for now
42
- #RUN flutter build web
43
-
44
- WORKDIR $HOME/app
45
-
46
- COPY --chown=user . $HOME/app
47
-
48
- EXPOSE 8080
49
-
50
- ENV PORT 8080
51
-
52
- CMD python3 api.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
reference_example/api.py DELETED
@@ -1,297 +0,0 @@
1
- import asyncio
2
- import json
3
- import logging
4
- import os
5
- import pathlib
6
- import time
7
- import uuid
8
- from aiohttp import web, WSMsgType
9
- from typing import Dict, Any
10
-
11
- from api_core import VideoGenerationAPI
12
- from api_session import SessionManager
13
- from api_metrics import MetricsTracker
14
- from api_config import *
15
-
16
- # Configure logging
17
- logging.basicConfig(
18
- level=logging.INFO,
19
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
20
- )
21
- logger = logging.getLogger(__name__)
22
-
23
- # Create global session and metrics managers
24
- session_manager = SessionManager()
25
- metrics_tracker = MetricsTracker()
26
-
27
- # Dictionary to track connected anonymous clients by IP address
28
- anon_connections = {}
29
- anon_connection_lock = asyncio.Lock()
30
-
31
- async def status_handler(request: web.Request) -> web.Response:
32
- """Handler for API status endpoint"""
33
- api = session_manager.shared_api
34
-
35
- # Get current busy status of all endpoints
36
- endpoint_statuses = []
37
- for ep in api.endpoint_manager.endpoints:
38
- endpoint_statuses.append({
39
- 'id': ep.id,
40
- 'url': ep.url,
41
- 'busy': ep.busy,
42
- 'last_used': ep.last_used,
43
- 'error_count': ep.error_count,
44
- 'error_until': ep.error_until
45
- })
46
-
47
- # Get session statistics
48
- session_stats = session_manager.get_session_stats()
49
-
50
- # Get metrics
51
- api_metrics = metrics_tracker.get_metrics()
52
-
53
- return web.json_response({
54
- 'product': PRODUCT_NAME,
55
- 'version': PRODUCT_VERSION,
56
- 'maintenance_mode': MAINTENANCE_MODE,
57
- 'available_endpoints': len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS),
58
- 'endpoint_status': endpoint_statuses,
59
- 'active_endpoints': sum(1 for ep in endpoint_statuses if not ep['busy'] and ('error_until' not in ep or ep['error_until'] < time.time())),
60
- 'active_sessions': session_stats,
61
- 'metrics': api_metrics
62
- })
63
-
64
- async def metrics_handler(request: web.Request) -> web.Response:
65
- """Handler for detailed metrics endpoint (protected)"""
66
- # Check for API key in header or query param
67
- auth_header = request.headers.get('Authorization', '')
68
- api_key = None
69
-
70
- if auth_header.startswith('Bearer '):
71
- api_key = auth_header[7:]
72
- else:
73
- api_key = request.query.get('key')
74
-
75
- # Validate API key (using SECRET_TOKEN as the API key)
76
- if not api_key or api_key != SECRET_TOKEN:
77
- return web.json_response({
78
- 'error': 'Unauthorized'
79
- }, status=401)
80
-
81
- # Get detailed metrics
82
- detailed_metrics = metrics_tracker.get_detailed_metrics()
83
-
84
- return web.json_response(detailed_metrics)
85
-
86
- async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
87
- # Check if maintenance mode is enabled
88
- if MAINTENANCE_MODE:
89
- # Return an error response indicating maintenance mode
90
- return web.json_response({
91
- 'error': 'Server is in maintenance mode',
92
- 'maintenance': True
93
- }, status=503) # 503 Service Unavailable
94
-
95
- ws = web.WebSocketResponse(
96
- max_msg_size=1024*1024*20, # 20MB max message size
97
- timeout=30.0 # we want to keep things tight and short
98
- )
99
-
100
- await ws.prepare(request)
101
-
102
- # Get the Hugging Face token from query parameters
103
- hf_token = request.query.get('hf_token', '')
104
-
105
- # Generate a unique user ID for this connection
106
- user_id = str(uuid.uuid4())
107
-
108
- # Validate the token and determine the user role
109
- user_role = await session_manager.shared_api.validate_user_token(hf_token)
110
- logger.info(f"User {user_id} connected with role: {user_role}")
111
-
112
- # Get client IP address
113
- peername = request.transport.get_extra_info('peername')
114
- if peername is not None:
115
- client_ip = peername[0]
116
- else:
117
- client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()
118
-
119
- logger.info(f"Client {user_id} connecting from IP: {client_ip} with role: {user_role}")
120
-
121
- # Check for anonymous user connection limits
122
- if user_role == 'anon':
123
- async with anon_connection_lock:
124
- # Track this connection
125
- anon_connections[client_ip] = anon_connections.get(client_ip, 0) + 1
126
- # Store the IP so we can clean up later
127
- ws.client_ip = client_ip
128
-
129
- # Log multiple connections from same IP but don't restrict them
130
- if anon_connections[client_ip] > 1:
131
- logger.info(f"Multiple anonymous connections from IP {client_ip}: {anon_connections[client_ip]} connections")
132
-
133
- # Store the user role in the websocket for easy access
134
- ws.user_role = user_role
135
- ws.user_id = user_id
136
-
137
- # Register with metrics
138
- metrics_tracker.register_session(user_id, client_ip)
139
-
140
- # Create a new session for this user
141
- user_session = await session_manager.create_session(user_id, user_role, ws)
142
-
143
- try:
144
- async for msg in ws:
145
- if msg.type == WSMsgType.TEXT:
146
- try:
147
- data = json.loads(msg.data)
148
- action = data.get('action')
149
-
150
- # Check for rate limiting
151
- request_type = 'other'
152
- if action in ['join_chat', 'leave_chat', 'chat_message']:
153
- request_type = 'chat'
154
- elif action in ['generate_video']:
155
- request_type = 'video'
156
- elif action == 'search':
157
- request_type = 'search'
158
- elif action == 'simulate':
159
- request_type = 'simulation'
160
-
161
- # Record the request for metrics
162
- await metrics_tracker.record_request(user_id, client_ip, request_type, user_role)
163
-
164
- # Check rate limits (except for admins)
165
- if user_role != 'admin' and await metrics_tracker.is_rate_limited(user_id, request_type, user_role):
166
- await ws.send_json({
167
- 'action': action,
168
- 'requestId': data.get('requestId'),
169
- 'success': False,
170
- 'error': f'Rate limit exceeded for {request_type} requests. Please try again later.'
171
- })
172
- continue
173
-
174
- # Route requests to appropriate queues
175
- if action in ['join_chat', 'leave_chat', 'chat_message']:
176
- await user_session.chat_queue.put(data)
177
- elif action in ['generate_video']:
178
- await user_session.video_queue.put(data)
179
- elif action == 'search':
180
- await user_session.search_queue.put(data)
181
- elif action == 'simulate':
182
- await user_session.simulation_queue.put(data)
183
- else:
184
- await user_session.process_generic_request(data)
185
-
186
- except Exception as e:
187
- logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
188
- await ws.send_json({
189
- 'action': data.get('action') if 'data' in locals() else 'unknown',
190
- 'success': False,
191
- 'error': f'Error processing message: {str(e)}'
192
- })
193
-
194
- elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
195
- break
196
-
197
- finally:
198
- # Cleanup session
199
- await session_manager.delete_session(user_id)
200
-
201
- # Cleanup anonymous connection tracking
202
- if getattr(ws, 'user_role', None) == 'anon' and hasattr(ws, 'client_ip'):
203
- client_ip = ws.client_ip
204
- async with anon_connection_lock:
205
- if client_ip in anon_connections:
206
- anon_connections[client_ip] = max(0, anon_connections[client_ip] - 1)
207
- if anon_connections[client_ip] == 0:
208
- del anon_connections[client_ip]
209
- logger.info(f"Anonymous connection from {client_ip} closed. Remaining: {anon_connections.get(client_ip, 0)}")
210
-
211
- # Unregister from metrics
212
- metrics_tracker.unregister_session(user_id, client_ip)
213
- logger.info(f"Connection closed for user {user_id}")
214
-
215
- return ws
216
-
217
- async def init_app() -> web.Application:
218
- app = web.Application(
219
- client_max_size=1024**2*20 # 20MB max size
220
- )
221
-
222
- # Add cleanup logic
223
- async def cleanup(app):
224
- logger.info("Shutting down server, closing all sessions...")
225
- await session_manager.close_all_sessions()
226
-
227
- app.on_shutdown.append(cleanup)
228
-
229
- # Add routes
230
- app.router.add_get('/ws', websocket_handler)
231
- app.router.add_get('/api/status', status_handler)
232
- app.router.add_get('/api/metrics', metrics_handler)
233
-
234
- # Set up static file serving
235
- # Define the path to the public directory
236
- public_path = pathlib.Path(__file__).parent / 'build' / 'web'
237
- if not public_path.exists():
238
- public_path.mkdir(parents=True, exist_ok=True)
239
-
240
- # Set up static file serving with proper security considerations
241
- async def static_file_handler(request):
242
- # Get the path from the request (removing leading /)
243
- path_parts = request.path.lstrip('/').split('/')
244
-
245
- # Convert to safe path to prevent path traversal attacks
246
- safe_path = public_path.joinpath(*path_parts)
247
-
248
- # Make sure the path is within the public directory (prevent directory traversal)
249
- try:
250
- safe_path = safe_path.resolve()
251
- if not str(safe_path).startswith(str(public_path.resolve())):
252
- return web.HTTPForbidden(text="Access denied")
253
- except (ValueError, FileNotFoundError):
254
- return web.HTTPNotFound()
255
-
256
- # If path is a directory, look for index.html
257
- if safe_path.is_dir():
258
- safe_path = safe_path / 'index.html'
259
-
260
- # Check if the file exists
261
- if not safe_path.exists() or not safe_path.is_file():
262
- # If not found, serve index.html (for SPA routing)
263
- safe_path = public_path / 'index.html'
264
- if not safe_path.exists():
265
- return web.HTTPNotFound()
266
-
267
- # Determine content type based on file extension
268
- content_type = 'text/plain'
269
- ext = safe_path.suffix.lower()
270
- if ext == '.html':
271
- content_type = 'text/html'
272
- elif ext == '.js':
273
- content_type = 'application/javascript'
274
- elif ext == '.css':
275
- content_type = 'text/css'
276
- elif ext in ('.jpg', '.jpeg'):
277
- content_type = 'image/jpeg'
278
- elif ext == '.png':
279
- content_type = 'image/png'
280
- elif ext == '.gif':
281
- content_type = 'image/gif'
282
- elif ext == '.svg':
283
- content_type = 'image/svg+xml'
284
- elif ext == '.json':
285
- content_type = 'application/json'
286
-
287
- # Return the file with appropriate headers
288
- return web.FileResponse(safe_path, headers={'Content-Type': content_type})
289
-
290
- # Add catch-all route for static files (lower priority than API routes)
291
- app.router.add_get('/{path:.*}', static_file_handler)
292
-
293
- return app
294
-
295
- if __name__ == '__main__':
296
- app = asyncio.run(init_app())
297
- web.run_app(app, host='0.0.0.0', port=8080)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
reference_example/api_config.py DELETED
@@ -1,184 +0,0 @@
1
- import os
2
-
3
- PRODUCT_NAME = os.environ.get('PRODUCT_NAME', 'TikSlop')
4
- PRODUCT_VERSION = "2.0.0"
5
-
6
- # you should use Mistral 7b instruct for good performance and accuracy balance
7
- TEXT_MODEL = os.environ.get('HF_TEXT_MODEL', '')
8
-
9
- # Environment variable to control maintenance mode
10
- MAINTENANCE_MODE = os.environ.get('MAINTENANCE_MODE', 'false').lower() in ('true', 'yes', '1', 't')
11
-
12
- # Environment variable to control how many nodes to use
13
- MAX_NODES = int(os.environ.get('MAX_NODES', '8'))
14
-
15
- ADMIN_ACCOUNTS = [
16
- "jbilcke-hf"
17
- ]
18
-
19
- RAW_VIDEO_ROUND_ROBIN_ENDPOINT_URLS = [
20
- os.environ.get('VIDEO_ROUND_ROBIN_SERVER_1', ''),
21
- os.environ.get('VIDEO_ROUND_ROBIN_SERVER_2', ''),
22
- os.environ.get('VIDEO_ROUND_ROBIN_SERVER_3', ''),
23
- os.environ.get('VIDEO_ROUND_ROBIN_SERVER_4', ''),
24
- os.environ.get('VIDEO_ROUND_ROBIN_SERVER_5', ''),
25
- os.environ.get('VIDEO_ROUND_ROBIN_SERVER_6', ''),
26
- os.environ.get('VIDEO_ROUND_ROBIN_SERVER_7', ''),
27
- os.environ.get('VIDEO_ROUND_ROBIN_SERVER_8', ''),
28
- ]
29
-
30
- # Filter out empty strings from the endpoint list
31
- filtered_urls = [url for url in RAW_VIDEO_ROUND_ROBIN_ENDPOINT_URLS if url]
32
-
33
- # Limit the number of URLs based on MAX_NODES environment variable
34
- VIDEO_ROUND_ROBIN_ENDPOINT_URLS = filtered_urls[:MAX_NODES]
35
-
36
- HF_TOKEN = os.environ.get('HF_TOKEN')
37
-
38
- # use the same secret token as you used to secure your BASE_SPACE_NAME spaces
39
- SECRET_TOKEN = os.environ.get('SECRET_TOKEN')
40
-
41
- # altenative words we could use: "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"
42
- NEGATIVE_PROMPT = "low quality, worst quality, deformed, distorted, disfigured, blurry, text, watermark"
43
-
44
- POSITIVE_PROMPT_SUFFIX = "high quality, cinematic, 4K, intricate details"
45
-
46
- GUIDANCE_SCALE = 1.0
47
-
48
- THUMBNAIL_FRAMES = 65
49
-
50
- # anonymous users are people browing TikSlop without being connected
51
- # this category suffers from regular abuse so we need to enforce strict limitations
52
- CONFIG_FOR_ANONYMOUS_USERS = {
53
-
54
- # anons can only watch 2 minutes per video
55
- "max_rendering_time_per_client_per_video_in_sec": 2 * 60,
56
-
57
- "min_num_inference_steps": 2,
58
- "default_num_inference_steps": 4,
59
- "max_num_inference_steps": 4,
60
-
61
- "min_num_frames": 9, # 8 + 1
62
- "default_max_num_frames": 65, # 8*8 + 1
63
- "max_num_frames": 65, # 8*8 + 1
64
-
65
- "min_clip_duration_seconds": 1,
66
- "default_clip_duration_seconds": 2,
67
- "max_clip_duration_seconds": 2,
68
-
69
- "min_clip_playback_speed": 0.7,
70
- "default_clip_playback_speed": 0.7,
71
- "max_clip_playback_speed": 0.7,
72
-
73
- "min_clip_framerate": 8,
74
- "default_clip_framerate": 16,
75
- "max_clip_framerate": 16,
76
-
77
- "min_clip_width": 544,
78
- "default_clip_width": 640,
79
- "max_clip_width": 640,
80
-
81
- "min_clip_height": 320,
82
- "default_clip_height": 352,
83
- "max_clip_height": 352,
84
- }
85
-
86
- # Hugging Face users enjoy a more normal and calibrated experience
87
- CONFIG_FOR_STANDARD_HF_USERS = {
88
- "max_rendering_time_per_client_per_video_in_sec": 15 * 60,
89
-
90
- "min_num_inference_steps": 2,
91
- "default_num_inference_steps": 4,
92
- "max_num_inference_steps": 4,
93
-
94
- "min_num_frames": 9, # 8 + 1
95
- "default_num_frames": 81, # 8*10 + 1
96
- "max_num_frames": 81,
97
-
98
- "min_clip_duration_seconds": 1,
99
- "default_clip_duration_seconds": 3,
100
- "max_clip_duration_seconds": 3,
101
-
102
- "min_clip_playback_speed": 0.7,
103
- "default_clip_playback_speed": 0.7,
104
- "max_clip_playback_speed": 0.7,
105
-
106
- "min_clip_framerate": 8,
107
- "default_clip_framerate": 25,
108
- "max_clip_framerate": 25,
109
-
110
- "min_clip_width": 544,
111
- "default_clip_width": 1152, # 928, # 1216, # 768, # 640,
112
- "max_clip_width": 1152, # 928, # 1216, # 768, # 640,
113
-
114
- "min_clip_height": 320,
115
- "default_clip_height": 640, # 512, # 448, # 416,
116
- "max_clip_height": 640, # 512, # 448, # 416,
117
- }
118
-
119
- # Hugging Face users with a Pro may enjoy an improved experience
120
- CONFIG_FOR_PRO_HF_USERS = {
121
- "max_rendering_time_per_client_per_video_in_sec": 20 * 60,
122
-
123
- "min_num_inference_steps": 2,
124
- "default_num_inference_steps": 4,
125
- "max_num_inference_steps": 4,
126
-
127
- "min_num_frames": 9, # 8 + 1
128
- "default_num_frames": 81, # 8*10 + 1
129
- "max_num_frames": 81,
130
-
131
- "min_clip_duration_seconds": 1,
132
- "default_clip_duration_seconds": 3,
133
- "max_clip_duration_seconds": 3,
134
-
135
- "min_clip_playback_speed": 0.7,
136
- "default_clip_playback_speed": 0.7,
137
- "max_clip_playback_speed": 0.7,
138
-
139
- "min_clip_framerate": 8,
140
- "default_clip_framerate": 25,
141
- "max_clip_framerate": 25,
142
-
143
- "min_clip_width": 544,
144
- "default_clip_width": 1152, # 928, # 1216, # 768, # 640,
145
- "max_clip_width": 1152, # 928, # 1216, # 768, # 640,
146
-
147
- "min_clip_height": 320,
148
- "default_clip_height": 640, # 512, # 448, # 416,
149
- "max_clip_height": 640, # 512, # 448, # 416,
150
- }
151
-
152
- CONFIG_FOR_ADMIN_HF_USERS = {
153
- "max_rendering_time_per_client_per_video_in_sec": 60 * 60,
154
-
155
- "min_num_inference_steps": 2,
156
- "default_num_inference_steps": 4,
157
- "max_num_inference_steps": 4,
158
-
159
- "min_num_frames": 9, # 8 + 1
160
- "default_num_frames": 81, # (8 * 10) + 1
161
- "max_num_frames": 129, # (8 * 16) + 1
162
-
163
- "min_clip_duration_seconds": 1,
164
- "default_clip_duration_seconds": 2,
165
- "max_clip_duration_seconds": 4,
166
-
167
- "min_clip_playback_speed": 0.7,
168
- "default_clip_playback_speed": 0.7,
169
- "max_clip_playback_speed": 1.0,
170
-
171
- "min_clip_framerate": 8,
172
- "default_clip_framerate": 30,
173
- "max_clip_framerate": 60,
174
-
175
- "min_clip_width": 544,
176
- "default_clip_width": 1152, # 928, # 1216, # 768, # 640,
177
- "max_clip_width": 1152, # 928, # 1216, # 768, # 640,
178
-
179
- "min_clip_height": 320,
180
- "default_clip_height": 640, # 512, # 448, # 416,
181
- "max_clip_height": 640, # 512, # 448, # 416,
182
- }
183
-
184
- CONFIG_FOR_ADMIN_HF_USERS = CONFIG_FOR_PRO_HF_USERS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
reference_example/api_core.py DELETED
@@ -1,1068 +0,0 @@
1
- import logging
2
- import os
3
- import io
4
- import re
5
- import base64
6
- import uuid
7
- from typing import Dict, Any, Optional, List, Literal
8
- from dataclasses import dataclass
9
- from asyncio import Lock, Queue
10
- import asyncio
11
- import time
12
- import datetime
13
- from contextlib import asynccontextmanager
14
- from collections import defaultdict
15
- from aiohttp import web, ClientSession
16
- from huggingface_hub import InferenceClient, HfApi
17
- from gradio_client import Client
18
- import random
19
- import yaml
20
- import json
21
-
22
- from api_config import *
23
-
24
- # User role type
25
- UserRole = Literal['anon', 'normal', 'pro', 'admin']
26
-
27
- # Configure logging
28
- logging.basicConfig(
29
- level=logging.INFO,
30
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
31
- )
32
- logger = logging.getLogger(__name__)
33
-
34
-
35
- def generate_seed():
36
- """Generate a random positive 32-bit integer seed."""
37
- return random.randint(0, 2**32 - 1)
38
-
39
- def sanitize_yaml_response(response_text: str) -> str:
40
- """
41
- Sanitize and format AI response into valid YAML.
42
- Returns properly formatted YAML string.
43
- """
44
-
45
- response_text = response_text.split("```")[0]
46
-
47
- # Remove any markdown code block indicators and YAML document markers
48
- clean_text = re.sub(r'```yaml|```|---|\.\.\.$', '', response_text.strip())
49
-
50
- # Split into lines and process each line
51
- lines = clean_text.split('\n')
52
- sanitized_lines = []
53
- current_field = None
54
-
55
- for line in lines:
56
- stripped = line.strip()
57
- if not stripped:
58
- continue
59
-
60
- # Handle field starts
61
- if stripped.startswith('title:') or stripped.startswith('description:'):
62
- # Ensure proper YAML format with space after colon and proper quoting
63
- field_name = stripped.split(':', 1)[0]
64
- field_value = stripped.split(':', 1)[1].strip().strip('"\'')
65
-
66
- # Quote the value if it contains special characters
67
- if any(c in field_value for c in ':[]{},&*#?|-<>=!%@`'):
68
- field_value = f'"{field_value}"'
69
-
70
- sanitized_lines.append(f"{field_name}: {field_value}")
71
- current_field = field_name
72
-
73
- elif stripped.startswith('tags:'):
74
- sanitized_lines.append('tags:')
75
- current_field = 'tags'
76
-
77
- elif stripped.startswith('-') and current_field == 'tags':
78
- # Process tag values
79
- tag = stripped[1:].strip().strip('"\'')
80
- if tag:
81
- # Clean and format tag
82
- tag = re.sub(r'[^\x00-\x7F]+', '', tag) # Remove non-ASCII
83
- tag = re.sub(r'[^a-zA-Z0-9\s-]', '', tag) # Keep only alphanumeric and hyphen
84
- tag = tag.strip().lower().replace(' ', '-')
85
- if tag:
86
- sanitized_lines.append(f" - {tag}")
87
-
88
- elif current_field in ['title', 'description']:
89
- # Handle multi-line title/description continuation
90
- value = stripped.strip('"\'')
91
- if value:
92
- # Append to previous line
93
- prev = sanitized_lines[-1]
94
- sanitized_lines[-1] = f"{prev} {value}"
95
-
96
- # Ensure the YAML has all required fields
97
- required_fields = {'title', 'description', 'tags'}
98
- found_fields = {line.split(':')[0].strip() for line in sanitized_lines if ':' in line}
99
-
100
- for field in required_fields - found_fields:
101
- if field == 'tags':
102
- sanitized_lines.extend(['tags:', ' - default'])
103
- else:
104
- sanitized_lines.append(f'{field}: "No {field} provided"')
105
-
106
- return '\n'.join(sanitized_lines)
107
-
108
- @dataclass
109
- class Endpoint:
110
- id: int
111
- url: str
112
- busy: bool = False
113
- last_used: float = 0
114
- error_count: int = 0
115
- error_until: float = 0 # Timestamp until which this endpoint is considered in error state
116
-
117
- class EndpointManager:
118
- def __init__(self):
119
- self.endpoints: List[Endpoint] = []
120
- self.lock = Lock()
121
- self.initialize_endpoints()
122
- self.last_used_index = -1 # Track the last used endpoint for round-robin
123
-
124
- def initialize_endpoints(self):
125
- """Initialize the list of endpoints"""
126
- for i, url in enumerate(VIDEO_ROUND_ROBIN_ENDPOINT_URLS):
127
- endpoint = Endpoint(id=i + 1, url=url)
128
- self.endpoints.append(endpoint)
129
-
130
- def _get_next_free_endpoint(self):
131
- """Get the next available non-busy endpoint, or oldest endpoint if all are busy"""
132
- current_time = time.time()
133
-
134
- # First priority: Get any non-busy and non-error endpoint
135
- free_endpoints = [
136
- ep for ep in self.endpoints
137
- if not ep.busy and current_time > ep.error_until
138
- ]
139
-
140
- if free_endpoints:
141
- # Return the least recently used free endpoint
142
- return min(free_endpoints, key=lambda ep: ep.last_used)
143
-
144
- # Second priority: If all busy/error, use round-robin but skip error endpoints
145
- tried_count = 0
146
- next_index = self.last_used_index
147
-
148
- while tried_count < len(self.endpoints):
149
- next_index = (next_index + 1) % len(self.endpoints)
150
- tried_count += 1
151
-
152
- # If endpoint is not in error state, use it
153
- if current_time > self.endpoints[next_index].error_until:
154
- self.last_used_index = next_index
155
- return self.endpoints[next_index]
156
-
157
- # If all endpoints are in error state, use the one with earliest error expiry
158
- self.last_used_index = next_index
159
- return min(self.endpoints, key=lambda ep: ep.error_until)
160
-
161
- @asynccontextmanager
162
- async def get_endpoint(self, max_wait_time: int = 10):
163
- """Get the next available endpoint using a context manager"""
164
- start_time = time.time()
165
- endpoint = None
166
-
167
- try:
168
- while True:
169
- if time.time() - start_time > max_wait_time:
170
- raise TimeoutError(f"Could not acquire an endpoint within {max_wait_time} seconds")
171
-
172
- async with self.lock:
173
- # Get the next available endpoint using our selection strategy
174
- endpoint = self._get_next_free_endpoint()
175
-
176
- # Mark it as busy
177
- endpoint.busy = True
178
- endpoint.last_used = time.time()
179
- #logger.info(f"Using endpoint {endpoint.id} (busy: {endpoint.busy}, last used: {endpoint.last_used})")
180
- break
181
-
182
- yield endpoint
183
-
184
- finally:
185
- if endpoint:
186
- async with self.lock:
187
- endpoint.busy = False
188
- endpoint.last_used = time.time()
189
- # We don't need to put back into queue - our strategy now picks directly from the list
190
-
191
- class ChatRoom:
192
- def __init__(self):
193
- self.messages = []
194
- self.connected_clients = set()
195
- self.max_history = 100
196
-
197
- def add_message(self, message):
198
- self.messages.append(message)
199
- if len(self.messages) > self.max_history:
200
- self.messages.pop(0)
201
-
202
- def get_recent_messages(self, limit=50):
203
- return self.messages[-limit:]
204
-
205
- class VideoGenerationAPI:
206
- def __init__(self):
207
- self.inference_client = InferenceClient(token=HF_TOKEN)
208
- self.hf_api = HfApi(token=HF_TOKEN)
209
- self.endpoint_manager = EndpointManager()
210
- self.active_requests: Dict[str, asyncio.Future] = {}
211
- self.chat_rooms = defaultdict(ChatRoom)
212
- self.video_events: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
213
- self.event_history_limit = 50
214
- # Cache for user roles to avoid repeated API calls
215
- self.user_role_cache: Dict[str, Dict[str, Any]] = {}
216
- # Cache expiration time (10 minutes)
217
- self.cache_expiration = 600
218
-
219
-
220
- def _add_event(self, video_id: str, event: Dict[str, Any]):
221
- """Add an event to the video's history and maintain the size limit"""
222
- events = self.video_events[video_id]
223
- events.append(event)
224
- if len(events) > self.event_history_limit:
225
- events.pop(0)
226
-
227
- async def validate_user_token(self, token: str) -> UserRole:
228
- """
229
- Validates a Hugging Face token and determines the user's role.
230
-
231
- Returns one of:
232
- - 'anon': Anonymous user (no token or invalid token)
233
- - 'normal': Standard Hugging Face user
234
- - 'pro': Hugging Face Pro user
235
- - 'admin': Admin user (username in ADMIN_ACCOUNTS)
236
- """
237
- # If no token is provided, the user is anonymous
238
- if not token:
239
- return 'anon'
240
-
241
- # Check if we have a cached result for this token
242
- current_time = time.time()
243
- if token in self.user_role_cache:
244
- cached_data = self.user_role_cache[token]
245
- # If the cache is still valid
246
- if current_time - cached_data['timestamp'] < self.cache_expiration:
247
- logger.info(f"Using cached user role: {cached_data['role']}")
248
- return cached_data['role']
249
-
250
- # No valid cache, need to check the token with the HF API
251
- try:
252
- # Use HF API to validate the token and get user info
253
- logger.info("Validating Hugging Face token...")
254
-
255
- # Run in executor to avoid blocking the event loop
256
- user_info = await asyncio.get_event_loop().run_in_executor(
257
- None,
258
- lambda: self.hf_api.whoami(token=token)
259
- )
260
-
261
- # Handle both object and dict response formats from whoami
262
- username = user_info.get('name') if isinstance(user_info, dict) else getattr(user_info, 'name', None)
263
- is_pro = user_info.get('is_pro') if isinstance(user_info, dict) else getattr(user_info, 'is_pro', False)
264
-
265
- if not username:
266
- logger.error(f"Could not determine username from user_info: {user_info}")
267
- return 'anon'
268
-
269
- logger.info(f"Token valid for user: {username}")
270
-
271
- # Determine the user role based on the information
272
- user_role: UserRole
273
-
274
- # Check if the user is an admin
275
- if username in ADMIN_ACCOUNTS:
276
- user_role = 'admin'
277
- # Check if the user has a pro account
278
- elif is_pro:
279
- user_role = 'pro'
280
- else:
281
- user_role = 'normal'
282
-
283
- # Cache the result
284
- self.user_role_cache[token] = {
285
- 'role': user_role,
286
- 'timestamp': current_time,
287
- 'username': username
288
- }
289
-
290
- return user_role
291
-
292
- except Exception as e:
293
- logger.error(f"Failed to validate Hugging Face token: {str(e)}")
294
- # If validation fails, the user is treated as anonymous
295
- return 'anon'
296
-
297
- async def download_video(self, url: str) -> bytes:
298
- """Download video file from URL and return bytes"""
299
- async with ClientSession() as session:
300
- async with session.get(url) as response:
301
- if response.status != 200:
302
- raise Exception(f"Failed to download video: HTTP {response.status}")
303
- return await response.read()
304
-
305
- async def search_video(self, query: str, attempt_count: int = 0) -> Optional[dict]:
306
- """Generate a single search result using HF text generation"""
307
- # Maximum number of attempts to generate a description without placeholder tags
308
- max_attempts = 2
309
- current_attempt = attempt_count
310
- # Use a random temperature between 0.68 and 0.72 to generate more diverse results
311
- # and prevent duplicate results from successive calls with the same prompt
312
- temperature = random.uniform(0.68, 0.72)
313
-
314
- while current_attempt <= max_attempts:
315
- prompt = f"""# Instruction
316
- Your response MUST be a YAML object containing a title and description, consistent with what we can find on a video sharing platform.
317
- Format your YAML response with only those fields: "title" (a short string) and "description" (string caption of the scene). Do not add any other field.
318
- In the description field, describe in a very synthetic way the visuals of the first shot (first scene), eg "<STYLE>, medium close-up shot, high angle view. In the foreground a <OPTIONAL AGE> <OPTIONAL GENDER> <CHARACTERS> <ACTIONS>. In the background <DESCRIBE LOCATION, BACKGROUND CHARACTERS, OBJECTS ETC>. The scene is lit by <LIGHTING> <WEATHER>". This is just an example! you MUST replace the <TAGS>!!.
319
- Don't forget to replace <STYLE> etc, by the actual fields!!
320
- For the style, be creative, for instance you can use anything like a "documentary footage", "japanese animation", "movie scene", "tv series", "tv show", "security footage" etc.
321
- If the user ask for something specific eg "movie screencap", "movie scene", "documentary footage" "animation" as a style etc.
322
- Keep it minimalist but still descriptive, don't use bullets points, use simple words, go to the essential to describe style (cinematic, documentary footage, 3D rendering..), camera modes and angles, characters, age, gender, action, location, lighting, country, costume, time, weather, textures, color palette.. etc). Write about 80 words, and use between 2 and 3 sentences.
323
- The most import part is to describe the actions and movements in the scene, so don't forget that!
324
- Don't describe sound, so ever say things like "atmospheric music playing in the background".
325
- Instead describe the visual elements we can see in the background, be precise, (if there are anything, cars, objects, people, bricks, birds, clouds, trees, leaves or grass then say it so etc).
326
- Make the result unique and different from previous search results. ONLY RETURN YAML AND WITH ENGLISH CONTENT, NOT CHINESE - DO NOT ADD ANY OTHER COMMENT!
327
-
328
- # Context
329
- This is attempt {current_attempt}.
330
-
331
- # Input
332
- Describe the first scene/shot for: "{query}".
333
-
334
- # Output
335
-
336
- ```yaml
337
- title: \""""
338
-
339
- try:
340
- response = await asyncio.get_event_loop().run_in_executor(
341
- None,
342
- lambda: self.inference_client.text_generation(
343
- prompt,
344
- model=TEXT_MODEL,
345
- max_new_tokens=200,
346
- temperature=temperature
347
- )
348
- )
349
-
350
- response_text = re.sub(r'^\s*\.\s*\n', '', f"title: \"{response.strip()}")
351
- sanitized_yaml = sanitize_yaml_response(response_text)
352
-
353
- try:
354
- result = yaml.safe_load(sanitized_yaml)
355
- except yaml.YAMLError as e:
356
- logger.error(f"YAML parsing failed: {str(e)}")
357
- result = None
358
-
359
- if not result or not isinstance(result, dict):
360
- logger.error(f"Invalid result format: {result}")
361
- current_attempt += 1
362
- temperature = random.uniform(0.68, 0.72) # Try with different random temperature on next attempt
363
- continue
364
-
365
- # Extract fields with defaults
366
- title = str(result.get('title', '')).strip() or 'Untitled Video'
367
- description = str(result.get('description', '')).strip() or 'No description available'
368
-
369
- # Check if the description still contains placeholder tags like <LOCATION>, <GENDER>, etc.
370
- if re.search(r'<[A-Z_]+>', description):
371
- #logger.warning(f"Description still contains placeholder tags: {description}")
372
- if current_attempt < max_attempts:
373
- # Try again with a different random temperature
374
- current_attempt += 1
375
- temperature = random.uniform(0.68, 0.72)
376
- continue
377
- else:
378
- # If we've reached max attempts, use the title as description
379
- description = title
380
-
381
- # Return valid result with all required fields
382
- return {
383
- 'id': str(uuid.uuid4()),
384
- 'title': title,
385
- 'description': description,
386
- 'thumbnailUrl': '',
387
- 'videoUrl': '',
388
-
389
- # not really used yet, maybe one day if we pre-generate or store content
390
- 'isLatent': True,
391
-
392
- 'useFixedSeed': "webcam" in description.lower(),
393
-
394
- 'seed': generate_seed(),
395
- 'views': 0,
396
- 'tags': []
397
- }
398
-
399
- except Exception as e:
400
- logger.error(f"Search video generation failed: {str(e)}")
401
- current_attempt += 1
402
- temperature = random.uniform(0.68, 0.72) # Try with different random temperature on next attempt
403
-
404
- # If all attempts failed, return a simple result with title only
405
- return {
406
- 'id': str(uuid.uuid4()),
407
- 'title': f"Video about {query}",
408
- 'description': f"Video about {query}",
409
- 'thumbnailUrl': '',
410
- 'videoUrl': '',
411
- 'isLatent': True,
412
- 'useFixedSeed': "query" in description.lower(),
413
- 'seed': generate_seed(),
414
- 'views': 0,
415
- 'tags': []
416
- }
417
-
418
- # The generate_thumbnail function has been removed because we now use
419
- # generate_video_thumbnail for all thumbnails, which generates a video clip
420
- # instead of a static image
421
-
422
- async def generate_caption(self, title: str, description: str) -> str:
423
- """Generate detailed caption using HF text generation"""
424
- try:
425
- prompt = f"""Generate a detailed story for a video named: "{title}"
426
- Visual description of the video: {description}.
427
- Instructions: Write the story summary, including the plot, action, what should happen.
428
- Make it around 200-300 words long.
429
- A video can be anything from a tutorial, webcam, trailer, movie, live stream etc."""
430
-
431
- response = await asyncio.get_event_loop().run_in_executor(
432
- None,
433
- lambda: self.inference_client.text_generation(
434
- prompt,
435
- model=TEXT_MODEL,
436
- max_new_tokens=180,
437
- temperature=0.7
438
- )
439
- )
440
-
441
- if "Caption: " in response:
442
- response = response.replace("Caption: ", "")
443
-
444
- chunks = f" {response} ".split(". ")
445
- if len(chunks) > 1:
446
- text = ". ".join(chunks[:-1])
447
- else:
448
- text = response
449
-
450
- return text.strip()
451
- except Exception as e:
452
- logger.error(f"Error generating caption: {str(e)}")
453
- return ""
454
-
455
- async def simulate(self, original_title: str, original_description: str,
456
- current_description: str, condensed_history: str,
457
- evolution_count: int = 0, chat_messages: str = '') -> dict:
458
- """
459
- Simulate a video by evolving its description to create a dynamic narrative.
460
-
461
- Args:
462
- original_title: The original video title
463
- original_description: The original video description
464
- current_description: The current description (last evolved or original if first evolution)
465
- condensed_history: A condensed summary of previous scene developments
466
- evolution_count: How many times the simulation has already evolved
467
- chat_messages: Chat messages from users to incorporate into the simulation
468
-
469
- Returns:
470
- A dictionary containing the evolved description and updated condensed history
471
- """
472
- try:
473
- # Determine if this is the first simulation
474
- is_first_simulation = evolution_count == 0 or not condensed_history
475
-
476
- logger.info(f"simulate(): is_first_simulation={is_first_simulation}")
477
-
478
- # Create an appropriate prompt based on whether this is the first simulation
479
- chat_section = ""
480
- if chat_messages:
481
- chat_section = f"""
482
- People are watching this content right now and have shared their thoughts. Like a game master, please take their feedback as input to adjust the story and/or the scene. Here are their messages:
483
-
484
- {chat_messages}
485
- """
486
-
487
- if is_first_simulation:
488
- prompt = f"""You are tasked with evolving the narrative for a video titled: "{original_title}"
489
-
490
- Original description:
491
- {original_description}
492
- {chat_section}
493
-
494
- Instructions:
495
- 1. Imagine the next logical scene or development that would follow this description.
496
- 2. Create a compelling new description (200-300 words) that builds on the original but introduces new elements, developments, or perspectives.
497
- 3. Maintain the original style, tone, and setting.
498
- 4. If viewers have shared messages, consider their input and incorporate relevant suggestions or reactions into your narrative evolution.
499
- 5. Also create a brief "scene history" (50-75 words) that summarizes what has happened so far.
500
-
501
- Return your response in this format:
502
- EVOLVED_DESCRIPTION: [your new evolved description here]
503
- CONDENSED_HISTORY: [your scene history summary]"""
504
- else:
505
- prompt = f"""You are tasked with continuing to evolve the narrative for a video titled: "{original_title}"
506
-
507
- Original description:
508
- {original_description}
509
-
510
- Condensed history of scenes so far:
511
- {condensed_history}
512
-
513
- Current description (most recent scene):
514
- {current_description}
515
- {chat_section}
516
-
517
- Instructions:
518
- 1. Imagine the next logical scene or development that would follow the current description.
519
- 2. Create a compelling new description (200-300 words) that builds on the narrative but introduces new elements, developments, or perspectives.
520
- 3. Maintain consistency with the previous scenes while advancing the story.
521
- 4. If viewers have shared messages, consider their input and incorporate relevant suggestions or reactions into your narrative evolution.
522
- 5. Also update the condensed history (50-75 words) to include this new development.
523
-
524
- Return your response in this format:
525
- EVOLVED_DESCRIPTION: [your new evolved description here]
526
- CONDENSED_HISTORY: [your updated scene history summary]"""
527
-
528
- # Generate the evolved description
529
- response = await asyncio.get_event_loop().run_in_executor(
530
- None,
531
- lambda: self.inference_client.text_generation(
532
- prompt,
533
- model=TEXT_MODEL,
534
- max_new_tokens=200,
535
- temperature=0.7
536
- )
537
- )
538
-
539
- # Extract the evolved description and condensed history from the response
540
- evolved_description = ""
541
- new_condensed_history = ""
542
-
543
- # Parse the response
544
- if "EVOLVED_DESCRIPTION:" in response and "CONDENSED_HISTORY:" in response:
545
- parts = response.split("CONDENSED_HISTORY:")
546
- if len(parts) >= 2:
547
- desc_part = parts[0].strip()
548
- if "EVOLVED_DESCRIPTION:" in desc_part:
549
- evolved_description = desc_part.split("EVOLVED_DESCRIPTION:", 1)[1].strip()
550
- new_condensed_history = parts[1].strip()
551
-
552
- # If parsing failed, use some fallbacks
553
- if not evolved_description:
554
- evolved_description = current_description
555
- logger.warning(f"Failed to parse evolved description, using current description as fallback")
556
-
557
- if not new_condensed_history and condensed_history:
558
- new_condensed_history = condensed_history
559
- logger.warning(f"Failed to parse condensed history, using current history as fallback")
560
- elif not new_condensed_history:
561
- new_condensed_history = f"The video begins with {original_title}: {original_description[:100]}..."
562
-
563
- return {
564
- "evolved_description": evolved_description,
565
- "condensed_history": new_condensed_history
566
- }
567
-
568
- except Exception as e:
569
- logger.error(f"Error simulating video: {str(e)}")
570
- return {
571
- "evolved_description": current_description,
572
- "condensed_history": condensed_history or f"The video shows {original_title}."
573
- }
574
-
575
-
576
- def get_config_value(self, role: UserRole, field: str, options: dict = None) -> Any:
577
- """
578
- Get the appropriate config value for a user role.
579
-
580
- Args:
581
- role: The user role ('anon', 'normal', 'pro', 'admin')
582
- field: The config field name to retrieve
583
- options: Optional user-provided options that may override defaults
584
-
585
- Returns:
586
- The config value appropriate for the user's role with respect to
587
- min/max boundaries and user overrides.
588
- """
589
- # Select the appropriate config based on user role
590
- if role == 'admin':
591
- config = CONFIG_FOR_ADMIN_HF_USERS
592
- elif role == 'pro':
593
- config = CONFIG_FOR_PRO_HF_USERS
594
- elif role == 'normal':
595
- config = CONFIG_FOR_STANDARD_HF_USERS
596
- else: # Anonymous users
597
- config = CONFIG_FOR_ANONYMOUS_USERS
598
-
599
- # Get the default value for this field from the config
600
- default_value = config.get(f"default_{field}", None)
601
-
602
- # For fields that have min/max bounds
603
- min_field = f"min_{field}"
604
- max_field = f"max_{field}"
605
-
606
- # Check if min/max constraints exist for this field
607
- has_constraints = min_field in config or max_field in config
608
-
609
- if not has_constraints:
610
- # For fields without constraints, just return the value from config
611
- return default_value
612
-
613
- # Get min and max values from config (if they exist)
614
- min_value = config.get(min_field, None)
615
- max_value = config.get(max_field, None)
616
-
617
- # If user provided options with this field
618
- if options and field in options:
619
- user_value = options[field]
620
-
621
- # Apply constraints if they exist
622
- if min_value is not None and user_value < min_value:
623
- return min_value
624
- if max_value is not None and user_value > max_value:
625
- return max_value
626
-
627
- # If within bounds, use the user's value
628
- return user_value
629
-
630
- # If no user value, return the default
631
- return default_value
632
-
633
- async def _generate_clip_prompt(self, video_id: str, title: str, description: str) -> str:
634
- """Generate a new prompt for the next clip based on event history"""
635
- events = self.video_events.get(video_id, [])
636
- events_json = "\n".join(json.dumps(event) for event in events)
637
-
638
- prompt = f"""# Context and task
639
- Please write the caption for a new clip.
640
-
641
- # Instructions
642
- 1. Consider the video context and recent events
643
- 2. Create a natural progression from previous clips
644
- 3. Take into account user suggestions (chat messages) into the scene
645
- 4. Don't generate hateful, political, violent or sexual content
646
- 5. Keep visual consistency with previous clips (in most cases you should repeat the same exact description of the location, characters etc but only change a few elements. If this is a webcam scenario, don't touch the camera orientation or focus)
647
- 6. Return ONLY the caption text, no additional formatting or explanation
648
- 7. Write in English, about 200 words.
649
- 8. Keep the visual style consistant, but content as well (repeat the style, character, locations, appearance etc.. across scenes, when it makes sense).
650
- 8. Your caption must describe visual elements of the scene in details, including: camera angle and focus, people's appearance, age, look, costumes, clothes, the location visual characteristics and geometry, lighting, action, objects, weather, textures, lighting.
651
-
652
- # Examples
653
- Here is a demo scenario, with fake data:
654
- {{"time": "2024-11-29T13:36:15Z", "event": "new_stream_clip", "caption": "webcam view of a beautiful park, squirrels are playing in the lush grass, blablabla etc... (rest omitted for brevity)"}}
655
- {{"time": "2024-11-29T13:36:20Z", "event": "new_chat_message", "username": "MonkeyLover89", "data": "hi"}}
656
- {{"time": "2024-11-29T13:36:25Z", "event": "new_chat_message", "username": "MonkeyLover89", "data": "more squirrels plz"}}
657
- {{"time": "2024-11-29T13:36:26Z", "event": "new_stream_clip", "caption": "webcam view of a beautiful park, a lot of squirrels are playing in the lush grass, blablabla etc... (rest omitted for brevity)"}}
658
-
659
- # Real scenario and data
660
-
661
- We are inside a video titled "{title}"
662
- The video is described by: "{description}".
663
- Here is a summary of the {len(events)} most recent events:
664
- {events_json}
665
-
666
- # Your response
667
- Your caption:"""
668
-
669
- try:
670
- response = await asyncio.get_event_loop().run_in_executor(
671
- None,
672
- lambda: self.inference_client.text_generation(
673
- prompt,
674
- model=TEXT_MODEL,
675
- max_new_tokens=200,
676
- temperature=0.7
677
- )
678
- )
679
-
680
- # Clean up the response
681
- caption = response.strip()
682
- if caption.lower().startswith("caption:"):
683
- caption = caption[8:].strip()
684
-
685
- return caption
686
-
687
- except Exception as e:
688
- logger.error(f"Error generating clip prompt: {str(e)}")
689
- # Fallback to original description if prompt generation fails
690
- return description
691
-
692
- async def generate_video_thumbnail(self, title: str, description: str, video_prompt_prefix: str, options: dict, user_role: UserRole = 'anon') -> str:
693
- """
694
- Generate a short, low-resolution video thumbnail for search results and previews.
695
- Optimized for quick generation and low resource usage.
696
- """
697
- video_id = options.get('video_id', str(uuid.uuid4()))
698
- seed = options.get('seed', generate_seed())
699
- request_id = str(uuid.uuid4())[:8] # Generate a short ID for logging
700
-
701
- logger.info(f"[{request_id}] Starting video thumbnail generation for video_id: {video_id}")
702
- logger.info(f"[{request_id}] Title: '{title}', User role: {user_role}")
703
-
704
- # Create a more concise prompt for the thumbnail
705
- clip_caption = f"{video_prompt_prefix} - {title.strip()}"
706
-
707
- # Add the thumbnail generation to event history
708
- self._add_event(video_id, {
709
- "time": datetime.datetime.utcnow().isoformat() + "Z",
710
- "event": "thumbnail_generation",
711
- "caption": clip_caption,
712
- "seed": seed,
713
- "request_id": request_id
714
- })
715
-
716
- # Use a shorter prompt for thumbnails
717
- prompt = f"{clip_caption}, {POSITIVE_PROMPT_SUFFIX}"
718
- logger.info(f"[{request_id}] Using prompt: '{prompt}'")
719
-
720
- # Specialized configuration for thumbnails - smaller size, single frame
721
- width = 512 # Reduced size for thumbnails
722
- height = 288 # 16:9 aspect ratio
723
- num_frames = THUMBNAIL_FRAMES # Just one frame for static thumbnail
724
- num_inference_steps = 4 # Fewer steps for faster generation
725
- frame_rate = 25 # Standard frame rate
726
-
727
- # Optionally override with options if specified
728
- width = options.get('width', width)
729
- height = options.get('height', height)
730
- num_frames = options.get('num_frames', num_frames)
731
- num_inference_steps = options.get('num_inference_steps', num_inference_steps)
732
- frame_rate = options.get('frame_rate', frame_rate)
733
-
734
- logger.info(f"[{request_id}] Configuration: width={width}, height={height}, frames={num_frames}, steps={num_inference_steps}, fps={frame_rate}")
735
-
736
- # Add thumbnail-specific tag to help debugging and metrics
737
- options['thumbnail'] = True
738
-
739
- # Check for available endpoints before attempting generation
740
- available_endpoints = sum(1 for ep in self.endpoint_manager.endpoints
741
- if not ep.busy and time.time() > ep.error_until)
742
- logger.info(f"[{request_id}] Available endpoints: {available_endpoints}/{len(self.endpoint_manager.endpoints)}")
743
-
744
- if available_endpoints == 0:
745
- logger.error(f"[{request_id}] No available endpoints for thumbnail generation")
746
- return ""
747
-
748
- # Use the same logic as regular video generation but with thumbnail settings
749
- try:
750
- # logger.info(f"[{request_id}] Generating thumbnail for video {video_id} with seed {seed}")
751
-
752
- start_time = time.time()
753
- # Rest of thumbnail generation logic same as regular video but with optimized settings
754
- result = await self._generate_video_content(
755
- prompt=prompt,
756
- negative_prompt=options.get('negative_prompt', NEGATIVE_PROMPT),
757
- width=width,
758
- height=height,
759
- num_frames=num_frames,
760
- num_inference_steps=num_inference_steps,
761
- frame_rate=frame_rate,
762
- seed=seed,
763
- options=options,
764
- user_role=user_role
765
- )
766
- duration = time.time() - start_time
767
-
768
- if result:
769
- data_length = len(result)
770
- logger.info(f"[{request_id}] Successfully generated thumbnail in {duration:.2f}s, data length: {data_length} chars")
771
- return result
772
- else:
773
- logger.error(f"[{request_id}] Empty result returned from video generation")
774
- return ""
775
-
776
- except Exception as e:
777
- logger.error(f"[{request_id}] Error generating thumbnail: {e}")
778
- if hasattr(e, "__traceback__"):
779
- import traceback
780
- logger.error(f"[{request_id}] Traceback: {traceback.format_exc()}")
781
- return "" # Return empty string instead of raising to avoid crashes
782
-
783
- async def generate_video(self, title: str, description: str, video_prompt_prefix: str, options: dict, user_role: UserRole = 'anon') -> str:
784
- """Generate video using available space from pool"""
785
- video_id = options.get('video_id', str(uuid.uuid4()))
786
-
787
- # Generate a new prompt based on event history
788
- #clip_caption = await self._generate_clip_prompt(video_id, title, description)
789
- clip_caption = f"{video_prompt_prefix} - {title.strip()} - {description.strip()}"
790
-
791
- # Add the new clip to event history
792
- self._add_event(video_id, {
793
- "time": datetime.datetime.utcnow().isoformat() + "Z",
794
- "event": "new_stream_clip",
795
- "caption": clip_caption
796
- })
797
-
798
- # Use the generated caption as the prompt
799
- prompt = f"{clip_caption}, {POSITIVE_PROMPT_SUFFIX}"
800
-
801
- # Get the config values based on user role
802
- width = self.get_config_value(user_role, 'clip_width', options)
803
- height = self.get_config_value(user_role, 'clip_height', options)
804
- num_frames = self.get_config_value(user_role, 'num_frames', options)
805
- num_inference_steps = self.get_config_value(user_role, 'num_inference_steps', options)
806
- frame_rate = self.get_config_value(user_role, 'clip_framerate', options)
807
-
808
- # Get orientation from options
809
- orientation = options.get('orientation', 'LANDSCAPE')
810
-
811
- # Adjust width and height based on orientation if needed
812
- if orientation == 'PORTRAIT' and width > height:
813
- # Swap width and height for portrait orientation
814
- width, height = height, width
815
- # logger.info(f"Orientation: {orientation}, swapped dimensions to width={width}, height={height}")
816
- elif orientation == 'LANDSCAPE' and height > width:
817
- # Swap height and width for landscape orientation
818
- height, width = width, height
819
- # logger.info(f"generate_video() Orientation: {orientation}, swapped dimensions to width={width}, height={height}, steps={num_inference_steps}, fps={frame_rate} | role: {user_role}")
820
- else:
821
- # logger.info(f"generate_video() Orientation: {orientation}, using original dimensions width={width}, height={height}, steps={num_inference_steps}, fps={frame_rate} | role: {user_role}")
822
- pass
823
-
824
- # Generate the video with standard settings
825
- return await self._generate_video_content(
826
- prompt=prompt,
827
- negative_prompt=options.get('negative_prompt', NEGATIVE_PROMPT),
828
- width=width,
829
- height=height,
830
- num_frames=num_frames,
831
- num_inference_steps=num_inference_steps,
832
- frame_rate=frame_rate,
833
- seed=options.get('seed', 42),
834
- options=options,
835
- user_role=user_role
836
- )
837
-
838
- async def _generate_video_content(self, prompt: str, negative_prompt: str, width: int,
839
- height: int, num_frames: int, num_inference_steps: int,
840
- frame_rate: int, seed: int, options: dict, user_role: UserRole) -> str:
841
- """
842
- Internal method to generate video content with specific parameters.
843
- Used by both regular video generation and thumbnail generation.
844
- """
845
- is_thumbnail = options.get('thumbnail', False)
846
- request_id = options.get('request_id', str(uuid.uuid4())[:8]) # Get or generate request ID
847
- video_id = options.get('video_id', 'unknown')
848
-
849
- # logger.info(f"[{request_id}] Generating {'thumbnail' if is_thumbnail else 'video'} for video {video_id} with seed {seed}")
850
-
851
- json_payload = {
852
- "inputs": {
853
- "prompt": prompt,
854
- },
855
- "parameters": {
856
- # ------------------- settings for LTX-Video -----------------------
857
- "negative_prompt": negative_prompt,
858
- "width": width,
859
- "height": height,
860
- "num_frames": num_frames,
861
- "num_inference_steps": num_inference_steps,
862
- "guidance_scale": options.get('guidance_scale', GUIDANCE_SCALE),
863
- "seed": seed,
864
-
865
- # ------------------- settings for Varnish -----------------------
866
- "double_num_frames": False, # <- False for real-time generation
867
- "fps": frame_rate,
868
- "super_resolution": False, # <- False for real-time generation
869
- "grain_amount": 0, # No film grain (on low-res, low-quality generation the effects aren't worth it + it adds weight to the MP4 payload)
870
- }
871
- }
872
-
873
- # Add thumbnail flag to help with metrics and debugging
874
- if is_thumbnail:
875
- json_payload["metadata"] = {
876
- "is_thumbnail": True,
877
- "thumbnail_version": "1.0",
878
- "request_id": request_id
879
- }
880
-
881
- # logger.info(f"[{request_id}] Waiting for an available endpoint...")
882
- async with self.endpoint_manager.get_endpoint() as endpoint:
883
- # logger.info(f"[{request_id}] Using endpoint {endpoint.id} for generation")
884
-
885
- try:
886
- async with ClientSession() as session:
887
- #logger.info(f"[{request_id}] Sending request to endpoint {endpoint.id}: {endpoint.url}")
888
- start_time = time.time()
889
-
890
- # Proceed with actual request
891
- async with session.post(
892
- endpoint.url,
893
- headers={
894
- "Accept": "application/json",
895
- "Authorization": f"Bearer {HF_TOKEN}",
896
- "Content-Type": "application/json",
897
- "X-Request-ID": request_id # Add request ID to headers
898
- },
899
- json=json_payload,
900
- timeout=12 # Extended timeout for thumbnails (was 8s)
901
- ) as response:
902
- request_duration = time.time() - start_time
903
- #logger.info(f"[{request_id}] Received response from endpoint {endpoint.id} in {request_duration:.2f}s: HTTP {response.status}")
904
-
905
- if response.status != 200:
906
- error_text = await response.text()
907
- logger.error(f"[{request_id}] Failed response: {error_text}")
908
- # Mark endpoint as in error state
909
- await self._mark_endpoint_error(endpoint)
910
- if "paused" in error_text:
911
- logger.error(f"[{request_id}] Endpoint is paused")
912
- return ""
913
- raise Exception(f"Video generation failed: HTTP {response.status} - {error_text}")
914
-
915
- result = await response.json()
916
- #logger.info(f"[{request_id}] Successfully parsed JSON response")
917
-
918
- if "error" in result:
919
- error_msg = result['error']
920
- logger.error(f"[{request_id}] Error in response: {error_msg}")
921
- # Mark endpoint as in error state
922
- await self._mark_endpoint_error(endpoint)
923
- if "paused" in str(error_msg).lower():
924
- logger.error(f"[{request_id}] Endpoint is paused")
925
- return ""
926
- raise Exception(f"Video generation failed: {error_msg}")
927
-
928
- video_data_uri = result.get("video")
929
- if not video_data_uri:
930
- logger.error(f"[{request_id}] No video data in response")
931
- # Mark endpoint as in error state
932
- await self._mark_endpoint_error(endpoint)
933
- raise Exception("No video data in response")
934
-
935
- # Get data size
936
- data_size = len(video_data_uri)
937
- #logger.info(f"[{request_id}] Received video data: {data_size} chars")
938
-
939
- # Reset error count on successful call
940
- endpoint.error_count = 0
941
- endpoint.error_until = 0
942
-
943
- return video_data_uri
944
-
945
- except asyncio.TimeoutError:
946
- # Handle timeout specifically
947
- logger.error(f"[{request_id}] Timeout occurred after {time.time() - start_time:.2f}s")
948
- await self._mark_endpoint_error(endpoint, is_timeout=True)
949
- return ""
950
- except Exception as e:
951
- # Handle all other exceptions
952
- logger.error(f"[{request_id}] Exception during video generation: {str(e)}")
953
- if not isinstance(e, asyncio.TimeoutError): # Already handled above
954
- await self._mark_endpoint_error(endpoint)
955
- return ""
956
-
957
- async def _mark_endpoint_error(self, endpoint: Endpoint, is_timeout: bool = False):
958
- """Mark an endpoint as being in error state with exponential backoff"""
959
- async with self.endpoint_manager.lock:
960
- endpoint.error_count += 1
961
-
962
- # Calculate backoff time exponentially based on error count
963
- # Start with 15 seconds, then 30, 60, etc. up to a max of 5 minutes
964
- # Using shorter backoffs since generation should be fast
965
- backoff_seconds = min(15 * (2 ** (endpoint.error_count - 1)), 300)
966
-
967
- # Add extra backoff for timeouts which are more indicative of serious issues
968
- if is_timeout:
969
- backoff_seconds *= 2
970
-
971
- endpoint.error_until = time.time() + backoff_seconds
972
-
973
- logger.warning(
974
- f"Endpoint {endpoint.id} marked as in error state (count: {endpoint.error_count}, "
975
- f"unavailable until: {datetime.datetime.fromtimestamp(endpoint.error_until).strftime('%H:%M:%S')})"
976
- )
977
-
978
-
979
- async def handle_chat_message(self, data: dict, ws: web.WebSocketResponse) -> dict:
980
- """Process and broadcast a chat message"""
981
- video_id = data.get('videoId')
982
- request_id = data.get('requestId')
983
-
984
- if not video_id:
985
- return {
986
- 'action': 'chat_message',
987
- 'requestId': request_id,
988
- 'success': False,
989
- 'error': 'No video ID provided'
990
- }
991
-
992
- # Add chat message to event history
993
- self._add_event(video_id, {
994
- "time": datetime.datetime.utcnow().isoformat() + "Z",
995
- "event": "new_chat_message",
996
- "username": data.get('username', 'Anonymous'),
997
- "data": data.get('content', '')
998
- })
999
-
1000
- room = self.chat_rooms[video_id]
1001
- message_data = {k: v for k, v in data.items() if k != '_ws'}
1002
- room.add_message(message_data)
1003
-
1004
- for client in room.connected_clients:
1005
- if client != ws:
1006
- try:
1007
- await client.send_json({
1008
- 'action': 'chat_message',
1009
- 'broadcast': True,
1010
- **message_data
1011
- })
1012
- except Exception as e:
1013
- logger.error(f"Failed to broadcast to client: {e}")
1014
- room.connected_clients.remove(client)
1015
-
1016
- return {
1017
- 'action': 'chat_message',
1018
- 'requestId': request_id,
1019
- 'success': True,
1020
- 'message': message_data
1021
- }
1022
-
1023
- async def handle_join_chat(self, data: dict, ws: web.WebSocketResponse) -> dict:
1024
- """Handle a request to join a chat room"""
1025
- video_id = data.get('videoId')
1026
- request_id = data.get('requestId')
1027
-
1028
- if not video_id:
1029
- return {
1030
- 'action': 'join_chat',
1031
- 'requestId': request_id,
1032
- 'success': False,
1033
- 'error': 'No video ID provided'
1034
- }
1035
-
1036
- room = self.chat_rooms[video_id]
1037
- room.connected_clients.add(ws)
1038
- recent_messages = room.get_recent_messages()
1039
-
1040
- return {
1041
- 'action': 'join_chat',
1042
- 'requestId': request_id,
1043
- 'success': True,
1044
- 'messages': recent_messages
1045
- }
1046
-
1047
- async def handle_leave_chat(self, data: dict, ws: web.WebSocketResponse) -> dict:
1048
- """Handle a request to leave a chat room"""
1049
- video_id = data.get('videoId')
1050
- request_id = data.get('requestId')
1051
-
1052
- if not video_id:
1053
- return {
1054
- 'action': 'leave_chat',
1055
- 'requestId': request_id,
1056
- 'success': False,
1057
- 'error': 'No video ID provided'
1058
- }
1059
-
1060
- room = self.chat_rooms[video_id]
1061
- if ws in room.connected_clients:
1062
- room.connected_clients.remove(ws)
1063
-
1064
- return {
1065
- 'action': 'leave_chat',
1066
- 'requestId': request_id,
1067
- 'success': True
1068
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
reference_example/api_metrics.py DELETED
@@ -1,185 +0,0 @@
1
- import time
2
- import logging
3
- import asyncio
4
- from collections import defaultdict
5
- from typing import Dict, List, Set, Optional
6
- import datetime
7
-
8
- logger = logging.getLogger(__name__)
9
-
10
- class MetricsTracker:
11
- """
12
- Tracks usage metrics across the API server.
13
- """
14
- def __init__(self):
15
- # Total metrics since server start
16
- self.total_requests = {
17
- 'chat': 0,
18
- 'video': 0,
19
- 'search': 0,
20
- 'other': 0,
21
- }
22
-
23
- # Per-user metrics
24
- self.user_metrics = defaultdict(lambda: {
25
- 'requests': {
26
- 'chat': 0,
27
- 'video': 0,
28
- 'search': 0,
29
- 'other': 0,
30
- },
31
- 'first_seen': time.time(),
32
- 'last_active': time.time(),
33
- 'role': 'anon'
34
- })
35
-
36
- # Rate limiting buckets (per minute)
37
- self.rate_limits = {
38
- 'anon': {
39
- 'video': 30,
40
- 'search': 45,
41
- 'chat': 90,
42
- 'other': 45
43
- },
44
- 'normal': {
45
- 'video': 60,
46
- 'search': 90,
47
- 'chat': 180,
48
- 'other': 90
49
- },
50
- 'pro': {
51
- 'video': 120,
52
- 'search': 180,
53
- 'chat': 300,
54
- 'other': 180
55
- },
56
- 'admin': {
57
- 'video': 240,
58
- 'search': 360,
59
- 'chat': 450,
60
- 'other': 360
61
- }
62
- }
63
-
64
- # Minute-based rate limiting buckets
65
- self.time_buckets = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
66
-
67
- # Lock for thread safety
68
- self.lock = asyncio.Lock()
69
-
70
- # Track concurrent sessions by IP
71
- self.ip_sessions = defaultdict(set)
72
-
73
- # Server start time
74
- self.start_time = time.time()
75
-
76
- async def record_request(self, user_id: str, ip: str, request_type: str, role: str):
77
- """Record a request for metrics and rate limiting"""
78
- async with self.lock:
79
- # Update total metrics
80
- if request_type in self.total_requests:
81
- self.total_requests[request_type] += 1
82
- else:
83
- self.total_requests['other'] += 1
84
-
85
- # Update user metrics
86
- user_data = self.user_metrics[user_id]
87
- user_data['last_active'] = time.time()
88
- user_data['role'] = role
89
-
90
- if request_type in user_data['requests']:
91
- user_data['requests'][request_type] += 1
92
- else:
93
- user_data['requests']['other'] += 1
94
-
95
- # Update time bucket for rate limiting
96
- current_minute = int(time.time() / 60)
97
- self.time_buckets[user_id][current_minute][request_type] += 1
98
-
99
- # Clean up old time buckets (keep only last 10 minutes)
100
- cutoff = current_minute - 10
101
- for minute in list(self.time_buckets[user_id].keys()):
102
- if minute < cutoff:
103
- del self.time_buckets[user_id][minute]
104
-
105
- def register_session(self, user_id: str, ip: str):
106
- """Register a new session for an IP address"""
107
- self.ip_sessions[ip].add(user_id)
108
-
109
- def unregister_session(self, user_id: str, ip: str):
110
- """Unregister a session when it disconnects"""
111
- if user_id in self.ip_sessions[ip]:
112
- self.ip_sessions[ip].remove(user_id)
113
- if not self.ip_sessions[ip]:
114
- del self.ip_sessions[ip]
115
-
116
- def get_session_count_for_ip(self, ip: str) -> int:
117
- """Get the number of active sessions for an IP address"""
118
- return len(self.ip_sessions.get(ip, set()))
119
-
120
- async def is_rate_limited(self, user_id: str, request_type: str, role: str) -> bool:
121
- """Check if a user is currently rate limited for a request type"""
122
- async with self.lock:
123
- current_minute = int(time.time() / 60)
124
- prev_minute = current_minute - 1
125
-
126
- # Count requests in current and previous minute
127
- current_count = self.time_buckets[user_id][current_minute][request_type]
128
- prev_count = self.time_buckets[user_id][prev_minute][request_type]
129
-
130
- # Calculate requests per minute rate (weighted average)
131
- # Weight current minute more as it's more recent
132
- rate = (current_count * 0.7) + (prev_count * 0.3)
133
-
134
- # Get rate limit based on user role
135
- limit = self.rate_limits.get(role, self.rate_limits['anon']).get(
136
- request_type, self.rate_limits['anon']['other'])
137
-
138
- # Check if rate exceeds limit
139
- return rate >= limit
140
-
141
- def get_metrics(self) -> Dict:
142
- """Get a snapshot of current metrics"""
143
- active_users = {
144
- 'total': len(self.user_metrics),
145
- 'anon': 0,
146
- 'normal': 0,
147
- 'pro': 0,
148
- 'admin': 0,
149
- }
150
-
151
- # Count active users in the last 5 minutes
152
- active_cutoff = time.time() - (5 * 60)
153
- for user_data in self.user_metrics.values():
154
- if user_data['last_active'] >= active_cutoff:
155
- active_users[user_data['role']] += 1
156
-
157
- return {
158
- 'uptime_seconds': int(time.time() - self.start_time),
159
- 'total_requests': dict(self.total_requests),
160
- 'active_users': active_users,
161
- 'active_ips': len(self.ip_sessions),
162
- 'timestamp': datetime.datetime.now().isoformat()
163
- }
164
-
165
- def get_detailed_metrics(self) -> Dict:
166
- """Get detailed metrics including per-user data"""
167
- metrics = self.get_metrics()
168
-
169
- # Add anonymized user metrics
170
- user_list = []
171
- for user_id, data in self.user_metrics.items():
172
- # Skip users inactive for more than 1 hour
173
- if time.time() - data['last_active'] > 3600:
174
- continue
175
-
176
- user_list.append({
177
- 'id': user_id[:8] + '...', # Anonymize ID
178
- 'role': data['role'],
179
- 'requests': data['requests'],
180
- 'active_ago': int(time.time() - data['last_active']),
181
- 'session_duration': int(time.time() - data['first_seen'])
182
- })
183
-
184
- metrics['users'] = user_list
185
- return metrics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
reference_example/api_session.py DELETED
@@ -1,569 +0,0 @@
1
- import asyncio
2
- import logging
3
- from typing import Dict, Set
4
- from aiohttp import web, WSMsgType
5
- import json
6
- import time
7
- import datetime
8
- from api_core import VideoGenerationAPI
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
- class UserSession:
13
- """
14
- Represents a user's session with the API.
15
- Each WebSocket connection gets its own session with separate queues and rate limits.
16
- """
17
- def __init__(self, user_id: str, user_role: str, ws: web.WebSocketResponse, shared_api):
18
- self.user_id = user_id
19
- self.user_role = user_role
20
- self.ws = ws
21
- self.shared_api = shared_api # For shared resources like endpoint manager
22
-
23
- # Create separate queues for this user session
24
- self.chat_queue = asyncio.Queue()
25
- self.video_queue = asyncio.Queue()
26
- self.search_queue = asyncio.Queue()
27
- self.simulation_queue = asyncio.Queue() # New queue for description evolution
28
-
29
- # Track request counts and rate limits
30
- self.request_counts = {
31
- 'chat': 0,
32
- 'video': 0,
33
- 'search': 0,
34
- 'simulation': 0 # New counter for simulation requests
35
- }
36
-
37
- # Last request timestamps for rate limiting
38
- self.last_request_times = {
39
- 'chat': time.time(),
40
- 'video': time.time(),
41
- 'search': time.time(),
42
- 'simulation': time.time() # New timestamp for simulation requests
43
- }
44
-
45
- # Session creation time
46
- self.created_at = time.time()
47
-
48
- self.background_tasks = []
49
-
50
- async def start(self):
51
- """Start all the queue processors for this session"""
52
- # Start background tasks for handling different request types
53
- self.background_tasks = [
54
- asyncio.create_task(self._process_chat_queue()),
55
- asyncio.create_task(self._process_video_queue()),
56
- asyncio.create_task(self._process_search_queue()),
57
- asyncio.create_task(self._process_simulation_queue()) # New worker for simulation requests
58
- ]
59
- logger.info(f"Started session for user {self.user_id} with role {self.user_role}")
60
-
61
- async def stop(self):
62
- """Stop all background tasks for this session"""
63
- for task in self.background_tasks:
64
- task.cancel()
65
-
66
- try:
67
- # Wait for tasks to complete cancellation
68
- await asyncio.gather(*self.background_tasks, return_exceptions=True)
69
- except asyncio.CancelledError:
70
- pass
71
-
72
- logger.info(f"Stopped session for user {self.user_id}")
73
-
74
- async def _process_chat_queue(self):
75
- """High priority queue for chat operations"""
76
- while True:
77
- data = await self.chat_queue.get()
78
- try:
79
- if data['action'] == 'join_chat':
80
- result = await self.shared_api.handle_join_chat(data, self.ws)
81
- elif data['action'] == 'chat_message':
82
- result = await self.shared_api.handle_chat_message(data, self.ws)
83
- elif data['action'] == 'leave_chat':
84
- result = await self.shared_api.handle_leave_chat(data, self.ws)
85
- # Redirect thumbnail requests to process_generic_request for consistent handling
86
- elif data['action'] == 'generate_video_thumbnail':
87
- # Pass to the generic request handler to maintain consistent logic
88
- await self.process_generic_request(data)
89
- # Skip normal response handling since process_generic_request already sends a response
90
- self.chat_queue.task_done()
91
- continue
92
- else:
93
- raise ValueError(f"Unknown chat action: {data['action']}")
94
-
95
- await self.ws.send_json(result)
96
-
97
- # Update metrics
98
- self.request_counts['chat'] += 1
99
- self.last_request_times['chat'] = time.time()
100
-
101
- except Exception as e:
102
- logger.error(f"Error processing chat request for user {self.user_id}: {e}")
103
- try:
104
- await self.ws.send_json({
105
- 'action': data['action'],
106
- 'requestId': data.get('requestId'),
107
- 'success': False,
108
- 'error': f'Chat error: {str(e)}'
109
- })
110
- except Exception as send_error:
111
- logger.error(f"Error sending error response: {send_error}")
112
- finally:
113
- self.chat_queue.task_done()
114
-
115
- async def _process_video_queue(self):
116
- """Process multiple video generation requests in parallel for this user"""
117
- from api_config import VIDEO_ROUND_ROBIN_ENDPOINT_URLS
118
-
119
- active_tasks = set()
120
- # Set a per-user concurrent limit based on role
121
- max_concurrent = len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS)
122
- if self.user_role == 'anon':
123
- max_concurrent = min(2, max_concurrent) # Limit anonymous users
124
- elif self.user_role == 'normal':
125
- max_concurrent = min(4, max_concurrent) # Standard users
126
- # Pro and admin can use all endpoints
127
-
128
- async def process_single_request(data):
129
- try:
130
- title = data.get('title', '')
131
- description = data.get('description', '')
132
- video_prompt_prefix = data.get('video_prompt_prefix', '')
133
- options = data.get('options', {})
134
-
135
- # Pass the user role to generate_video
136
- video_data = await self.shared_api.generate_video(
137
- title, description, video_prompt_prefix, options, self.user_role
138
- )
139
-
140
- result = {
141
- 'action': 'generate_video',
142
- 'requestId': data.get('requestId'),
143
- 'success': True,
144
- 'video': video_data,
145
- }
146
-
147
- await self.ws.send_json(result)
148
-
149
- # Update metrics
150
- self.request_counts['video'] += 1
151
- self.last_request_times['video'] = time.time()
152
-
153
- except Exception as e:
154
- logger.error(f"Error processing video request for user {self.user_id}: {e}")
155
- try:
156
- await self.ws.send_json({
157
- 'action': 'generate_video',
158
- 'requestId': data.get('requestId'),
159
- 'success': False,
160
- 'error': f'Video generation error: {str(e)}'
161
- })
162
- except Exception as send_error:
163
- logger.error(f"Error sending error response: {send_error}")
164
- finally:
165
- active_tasks.discard(asyncio.current_task())
166
-
167
- while True:
168
- # Clean up completed tasks
169
- active_tasks = {task for task in active_tasks if not task.done()}
170
-
171
- # Start new tasks if we have capacity
172
- while len(active_tasks) < max_concurrent:
173
- try:
174
- # Use try_get to avoid blocking if queue is empty
175
- data = await asyncio.wait_for(self.video_queue.get(), timeout=0.1)
176
-
177
- # Create and start new task
178
- task = asyncio.create_task(process_single_request(data))
179
- active_tasks.add(task)
180
-
181
- except asyncio.TimeoutError:
182
- # No items in queue, break inner loop
183
- break
184
- except Exception as e:
185
- logger.error(f"Error creating video generation task for user {self.user_id}: {e}")
186
- break
187
-
188
- # Wait a short time before checking queue again
189
- await asyncio.sleep(0.1)
190
-
191
- # Handle any completed tasks' errors
192
- for task in list(active_tasks):
193
- if task.done():
194
- try:
195
- await task
196
- except Exception as e:
197
- logger.error(f"Task failed with error for user {self.user_id}: {e}")
198
- active_tasks.discard(task)
199
-
200
- async def _process_search_queue(self):
201
- """Medium priority queue for search operations"""
202
- while True:
203
- try:
204
- data = await self.search_queue.get()
205
- request_id = data.get('requestId')
206
- query = data.get('query', '').strip()
207
- attempt_count = data.get('attemptCount', 0)
208
-
209
- # logger.info(f"Processing search request for user {self.user_id}, attempt={attempt_count}")
210
-
211
- if not query:
212
- logger.warning(f"Empty query received in request from user {self.user_id}: {data}")
213
- result = {
214
- 'action': 'search',
215
- 'requestId': request_id,
216
- 'success': False,
217
- 'error': 'No search query provided'
218
- }
219
- else:
220
- try:
221
- search_result = await self.shared_api.search_video(
222
- query,
223
- attempt_count=attempt_count
224
- )
225
-
226
- if search_result:
227
- # logger.info(f"Search successful for user {self.user_id}, query '{query}'")
228
- result = {
229
- 'action': 'search',
230
- 'requestId': request_id,
231
- 'success': True,
232
- 'result': search_result
233
- }
234
- else:
235
- # logger.warning(f"No results found for user {self.user_id}, query '{query}'")
236
- result = {
237
- 'action': 'search',
238
- 'requestId': request_id,
239
- 'success': False,
240
- 'error': 'No results found'
241
- }
242
- except Exception as e:
243
- logger.error(f"Search error for user {self.user_id}, (attempt {attempt_count}): {str(e)}")
244
- result = {
245
- 'action': 'search',
246
- 'requestId': request_id,
247
- 'success': False,
248
- 'error': f'Search error: {str(e)}'
249
- }
250
-
251
- await self.ws.send_json(result)
252
-
253
- # Update metrics
254
- self.request_counts['search'] += 1
255
- self.last_request_times['search'] = time.time()
256
-
257
- except Exception as e:
258
- logger.error(f"Error in search queue processor for user {self.user_id}: {str(e)}")
259
- try:
260
- error_response = {
261
- 'action': 'search',
262
- 'requestId': data.get('requestId') if 'data' in locals() else None,
263
- 'success': False,
264
- 'error': f'Internal server error: {str(e)}'
265
- }
266
- await self.ws.send_json(error_response)
267
- except Exception as send_error:
268
- logger.error(f"Error sending error response: {send_error}")
269
- finally:
270
- if 'search_queue' in self.__dict__:
271
- self.search_queue.task_done()
272
-
273
- async def _process_simulation_queue(self):
274
- """Dedicated queue for video simulation requests"""
275
- while True:
276
- try:
277
- data = await self.simulation_queue.get()
278
- request_id = data.get('requestId')
279
-
280
- # Extract parameters from the request
281
- video_id = data.get('video_id', '')
282
- original_title = data.get('original_title', '')
283
- original_description = data.get('original_description', '')
284
- current_description = data.get('current_description', '')
285
- condensed_history = data.get('condensed_history', '')
286
- evolution_count = data.get('evolution_count', 0)
287
- chat_messages = data.get('chat_messages', '')
288
-
289
- logger.info(f"Processing video simulation for user {self.user_id}, video_id={video_id}, evolution_count={evolution_count}")
290
-
291
- # Validate required parameters
292
- if not original_title or not original_description or not current_description:
293
- result = {
294
- 'action': 'simulate',
295
- 'requestId': request_id,
296
- 'success': False,
297
- 'error': 'Missing required parameters'
298
- }
299
- else:
300
- try:
301
- # Call the simulate method in the API
302
- simulation_result = await self.shared_api.simulate(
303
- original_title=original_title,
304
- original_description=original_description,
305
- current_description=current_description,
306
- condensed_history=condensed_history,
307
- evolution_count=evolution_count,
308
- chat_messages=chat_messages
309
- )
310
-
311
- result = {
312
- 'action': 'simulate',
313
- 'requestId': request_id,
314
- 'success': True,
315
- 'evolved_description': simulation_result['evolved_description'],
316
- 'condensed_history': simulation_result['condensed_history']
317
- }
318
- except Exception as e:
319
- logger.error(f"Error simulating video for user {self.user_id}, video_id={video_id}: {str(e)}")
320
- result = {
321
- 'action': 'simulate',
322
- 'requestId': request_id,
323
- 'success': False,
324
- 'error': f'Simulation error: {str(e)}'
325
- }
326
-
327
- await self.ws.send_json(result)
328
-
329
- # Update metrics
330
- self.request_counts['simulation'] += 1
331
- self.last_request_times['simulation'] = time.time()
332
-
333
- except Exception as e:
334
- logger.error(f"Error in simulation queue processor for user {self.user_id}: {str(e)}")
335
- try:
336
- error_response = {
337
- 'action': 'simulate',
338
- 'requestId': data.get('requestId') if 'data' in locals() else None,
339
- 'success': False,
340
- 'error': f'Internal server error: {str(e)}'
341
- }
342
- await self.ws.send_json(error_response)
343
- except Exception as send_error:
344
- logger.error(f"Error sending error response: {send_error}")
345
- finally:
346
- if 'simulation_queue' in self.__dict__:
347
- self.simulation_queue.task_done()
348
-
349
- async def process_generic_request(self, data: dict) -> None:
350
- """Handle general requests that don't fit into specialized queues"""
351
- try:
352
- request_id = data.get('requestId')
353
- action = data.get('action')
354
-
355
- def error_response(message: str):
356
- return {
357
- 'action': action,
358
- 'requestId': request_id,
359
- 'success': False,
360
- 'error': message
361
- }
362
-
363
- if action == 'heartbeat':
364
- # Include user role info in heartbeat response
365
- await self.ws.send_json({
366
- 'action': 'heartbeat',
367
- 'requestId': request_id,
368
- 'success': True,
369
- 'user_role': self.user_role
370
- })
371
-
372
- elif action == 'get_user_role':
373
- # Return the user role information
374
- await self.ws.send_json({
375
- 'action': 'get_user_role',
376
- 'requestId': request_id,
377
- 'success': True,
378
- 'user_role': self.user_role
379
- })
380
-
381
- elif action == 'generate_caption':
382
- title = data.get('params', {}).get('title')
383
- description = data.get('params', {}).get('description')
384
-
385
- if not title or not description:
386
- await self.ws.send_json(error_response('Missing title or description'))
387
- return
388
-
389
- caption = await self.shared_api.generate_caption(title, description)
390
- await self.ws.send_json({
391
- 'action': action,
392
- 'requestId': request_id,
393
- 'success': True,
394
- 'caption': caption
395
- })
396
-
397
- # evolve_description is now handled by the dedicated simulation queue processor
398
-
399
- elif action == 'generate_video_thumbnail':
400
- title = data.get('title', '') or data.get('params', {}).get('title', '')
401
- description = data.get('description', '') or data.get('params', {}).get('description', '')
402
- video_prompt_prefix = data.get('video_prompt_prefix', '') or data.get('params', {}).get('video_prompt_prefix', '')
403
- options = data.get('options', {}) or data.get('params', {}).get('options', {})
404
-
405
- if not title:
406
- await self.ws.send_json(error_response('Missing title for thumbnail generation'))
407
- return
408
-
409
- # Ensure the options include the thumbnail flag
410
- options['thumbnail'] = True
411
-
412
- # Prioritize thumbnail generation with higher priority
413
- options['priority'] = 'high'
414
-
415
- # Add small size settings if not already specified
416
- if 'width' not in options:
417
- options['width'] = 512 # Default thumbnail width
418
- if 'height' not in options:
419
- options['height'] = 288 # Default 16:9 aspect ratio
420
- if 'num_frames' not in options:
421
- options['num_frames'] = 25 # 1 second @ 25fps
422
-
423
- # Let the API know this is a thumbnail for a specific video
424
- options['video_id'] = data.get('video_id', f"thumbnail-{request_id}")
425
-
426
- logger.info(f"Generating thumbnail for video {options['video_id']} for user {self.user_id}")
427
-
428
- try:
429
- # Generate the thumbnail
430
- thumbnail_data = await self.shared_api.generate_video_thumbnail(
431
- title, description, video_prompt_prefix, options, self.user_role
432
- )
433
-
434
- # Respond with appropriate format based on the parameter names used in the request
435
- if 'thumbnailUrl' in data or 'thumbnailUrl' in data.get('params', {}):
436
- # Legacy format using thumbnailUrl
437
- await self.ws.send_json({
438
- 'action': action,
439
- 'requestId': request_id,
440
- 'success': True,
441
- 'thumbnailUrl': thumbnail_data or "",
442
- })
443
- else:
444
- # New format using thumbnail
445
- await self.ws.send_json({
446
- 'action': action,
447
- 'requestId': request_id,
448
- 'success': True,
449
- 'thumbnail': thumbnail_data,
450
- })
451
- except Exception as e:
452
- logger.error(f"Error generating thumbnail: {str(e)}")
453
- await self.ws.send_json(error_response(f"Thumbnail generation failed: {str(e)}"))
454
-
455
- # Handle deprecated thumbnail actions
456
- elif action == 'generate_thumbnail' or action == 'old_generate_thumbnail':
457
- # Redirect to video thumbnail generation
458
- logger.warning(f"Deprecated thumbnail action '{action}' used, redirecting to generate_video_thumbnail")
459
-
460
- # Extract parameters
461
- title = data.get('title', '') or data.get('params', {}).get('title', '')
462
- description = data.get('description', '') or data.get('params', {}).get('description', '')
463
-
464
- if not title or not description:
465
- await self.ws.send_json(error_response('Missing title or description'))
466
- return
467
-
468
- # Create a new request with the correct action
469
- new_request = {
470
- 'action': 'generate_video_thumbnail',
471
- 'requestId': request_id,
472
- 'title': title,
473
- 'description': description,
474
- 'options': {
475
- 'width': 512,
476
- 'height': 288,
477
- 'thumbnail': True,
478
- 'video_id': f"thumbnail-{request_id}"
479
- }
480
- }
481
-
482
- # Process with the new action
483
- await self.process_generic_request(new_request)
484
-
485
- else:
486
- await self.ws.send_json(error_response(f'Unknown action: {action}'))
487
-
488
- except Exception as e:
489
- logger.error(f"Error processing generic request for user {self.user_id}: {str(e)}")
490
- try:
491
- await self.ws.send_json({
492
- 'action': data.get('action'),
493
- 'requestId': data.get('requestId'),
494
- 'success': False,
495
- 'error': f'Internal server error: {str(e)}'
496
- })
497
- except Exception as send_error:
498
- logger.error(f"Error sending error response: {send_error}")
499
-
500
- class SessionManager:
501
- """
502
- Manages all active user sessions and shared resources.
503
- """
504
- def __init__(self):
505
- self.sessions = {}
506
- self.shared_api = VideoGenerationAPI() # Single instance for shared resources
507
- self.session_lock = asyncio.Lock()
508
-
509
- async def create_session(self, user_id: str, user_role: str, ws: web.WebSocketResponse) -> UserSession:
510
- """Create a new user session"""
511
- async with self.session_lock:
512
- # Create a new session for this user
513
- session = UserSession(user_id, user_role, ws, self.shared_api)
514
- await session.start()
515
- self.sessions[user_id] = session
516
- return session
517
-
518
- async def delete_session(self, user_id: str) -> None:
519
- """Delete a user session and clean up resources"""
520
- async with self.session_lock:
521
- if user_id in self.sessions:
522
- session = self.sessions[user_id]
523
- await session.stop()
524
- del self.sessions[user_id]
525
- logger.info(f"Deleted session for user {user_id}")
526
-
527
- def get_session(self, user_id: str) -> UserSession:
528
- """Get a user session if it exists"""
529
- return self.sessions.get(user_id)
530
-
531
- async def close_all_sessions(self) -> None:
532
- """Close all active sessions (used during shutdown)"""
533
- async with self.session_lock:
534
- for user_id, session in list(self.sessions.items()):
535
- await session.stop()
536
- self.sessions.clear()
537
- logger.info("Closed all active sessions")
538
-
539
- @property
540
- def session_count(self) -> int:
541
- """Get the number of active sessions"""
542
- return len(self.sessions)
543
-
544
- def get_session_stats(self) -> Dict:
545
- """Get statistics about active sessions"""
546
- stats = {
547
- 'total_sessions': len(self.sessions),
548
- 'by_role': {
549
- 'anon': 0,
550
- 'normal': 0,
551
- 'pro': 0,
552
- 'admin': 0
553
- },
554
- 'requests': {
555
- 'chat': 0,
556
- 'video': 0,
557
- 'search': 0,
558
- 'simulation': 0
559
- }
560
- }
561
-
562
- for session in self.sessions.values():
563
- stats['by_role'][session.user_role] += 1
564
- stats['requests']['chat'] += session.request_counts['chat']
565
- stats['requests']['video'] += session.request_counts['video']
566
- stats['requests']['search'] += session.request_counts['search']
567
- stats['requests']['simulation'] += session.request_counts['simulation']
568
-
569
- return stats
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/agent.py CHANGED
@@ -5,11 +5,11 @@ from typing import Optional, Union
5
  import torch
6
  import torch.nn as nn
7
 
8
- from envs import TorchEnv, WorldModelEnv
9
- from models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig
10
- from models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig
11
- from models.rew_end_model import RewEndModel, RewEndModelConfig
12
- from utils import extract_state_dict
13
 
14
 
15
  @dataclass
 
5
  import torch
6
  import torch.nn as nn
7
 
8
+ from src.envs import TorchEnv, WorldModelEnv
9
+ from src.models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig
10
+ from src.models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig
11
+ from src.models.rew_end_model import RewEndModel, RewEndModelConfig
12
+ from src.utils import extract_state_dict
13
 
14
 
15
  @dataclass
src/data/dataset.py CHANGED
@@ -13,7 +13,7 @@ from torch.utils.data import Dataset as TorchDataset
13
  from .episode import Episode
14
  from .segment import Segment, SegmentId
15
  from .utils import make_segment
16
- from utils import StateDictMixin
17
 
18
 
19
  class Dataset(StateDictMixin, TorchDataset):
 
13
  from .episode import Episode
14
  from .segment import Segment, SegmentId
15
  from .utils import make_segment
16
+ from src.utils import StateDictMixin
17
 
18
 
19
  class Dataset(StateDictMixin, TorchDataset):
src/game/dataset_env.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Tuple
3
  import torch
4
  from torch import Tensor
5
 
6
- from data import Dataset
7
 
8
 
9
  class DatasetEnv:
 
3
  import torch
4
  from torch import Tensor
5
 
6
+ from src.data import Dataset
7
 
8
 
9
  class DatasetEnv:
src/game/game.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  import pygame
5
  from PIL import Image
6
 
7
- from player.action_processing import GameAction
8
  from .dataset_env import DatasetEnv
9
  from .play_env import PlayEnv
10
 
 
4
  import pygame
5
  from PIL import Image
6
 
7
+ from src.player.action_processing import GameAction
8
  from .dataset_env import DatasetEnv
9
  from .play_env import PlayEnv
10
 
src/game/play_env.py CHANGED
@@ -7,11 +7,11 @@ import pygame
7
  import torch
8
  from torch import Tensor
9
 
10
- from agent import Agent
11
- from player.action_processing import GameAction, decode_game_action, encode_game_action, print_game_action
12
- from player.keymap import GAME_KEYMAP
13
- from data import Dataset, Episode
14
- from envs import WorldModelEnv
15
 
16
 
17
  NamedEnv = namedtuple("NamedEnv", "name env")
 
7
  import torch
8
  from torch import Tensor
9
 
10
+ from src.agent import Agent
11
+ from src.player.action_processing import GameAction, decode_game_action, encode_game_action, print_game_action
12
+ from src.player.keymap import GAME_KEYMAP
13
+ from src.data import Dataset, Episode
14
+ from src.envs import WorldModelEnv
15
 
16
 
17
  NamedEnv = namedtuple("NamedEnv", "name env")
src/main.py CHANGED
@@ -8,8 +8,8 @@ import torch
8
  from torch.distributed import init_process_group, destroy_process_group
9
  import torch.multiprocessing as mp
10
 
11
- from trainer import Trainer
12
- from utils import skip_if_run_is_over
13
 
14
 
15
  OmegaConf.register_new_resolver("eval", eval)
 
8
  from torch.distributed import init_process_group, destroy_process_group
9
  import torch.multiprocessing as mp
10
 
11
+ from src.trainer import Trainer
12
+ from src.utils import skip_if_run_is_over
13
 
14
 
15
  OmegaConf.register_new_resolver("eval", eval)
src/play.py CHANGED
@@ -7,9 +7,9 @@ from hydra.utils import instantiate
7
  from omegaconf import DictConfig, OmegaConf
8
  import torch
9
 
10
- from agent import Agent
11
- from envs import WorldModelEnv
12
- from game import Game, PlayEnv
13
 
14
 
15
  OmegaConf.register_new_resolver("eval", eval)
 
7
  from omegaconf import DictConfig, OmegaConf
8
  import torch
9
 
10
+ from src.agent import Agent
11
+ from src.envs import WorldModelEnv
12
+ from src.game import Game, PlayEnv
13
 
14
 
15
  OmegaConf.register_new_resolver("eval", eval)