diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..3152023feaef0811cf55ea93ff7d32c71a004d03
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,45 @@
+FROM nvidia/cuda:12.0.1-runtime-ubuntu22.04
+
+# Set environment variables
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ python3 \
+ python3-pip \
+ python3-dev \
+ ffmpeg \
+ libsm6 \
+ libxext6 \
+ libxrender-dev \
+ libglib2.0-0 \
+ git \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+# Set working directory
+WORKDIR /app
+
+# Copy requirements first to leverage Docker caching
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip3 install --no-cache-dir -r requirements.txt
+RUN pip3 install --no-cache-dir aiohttp
+
+# Install additional required packages
+RUN pip3 install --no-cache-dir torch torchvision torchaudio
+
+# Copy application code
+COPY . .
+
+# Create assets directory if it doesn't exist
+RUN mkdir -p /app/assets
+
+# Expose the port used by the server
+EXPOSE 8080
+
+# Set entry command
+CMD ["python3", "server.py", "--host", "0.0.0.0", "--port", "8080"]
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..08357e62af5bcd3f43c8028206d75238c4db1655
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,22 @@
+MIT License
+
+Copyright (c) 2024 Eloi Alonso
+Copyright (c) 2025 Enigma Labs AI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 4a1d97e427779dab9d5035c290b67dfaccc18b6d..212b895a1cb25e646708ed05f0bab4bf57d49284 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,170 @@
---
-title: Tikslop Gaming Multiverse
-emoji: 🏃
-colorFrom: purple
-colorTo: purple
+title: Multiverse
+emoji: 🐟
+colorFrom: blue
+colorTo: blue
sdk: docker
-pinned: false
+app_file: server.py
+pinned: true
+short_description: AI Multiplayer World Model
+app_port: 8080
+disable_embedding: false
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Multiverse: The First AI Multiplayer World Model
+
+🌐 [Enigma-AI website](https://enigma-labs.io/) - 📚 [Technical Blog](https://enigma-labs.io/) - [🤗 Model on Huggingface](https://huggingface.co/Enigma-AI/multiverse) - [🤗 Datasets on Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res) - 𝕏 [Multiverse Tweet](https://x.com/j0nathanj/status/1920516649511244258)
+
+
+*Two human players driving cars in Multiverse*
+
+---
+
+## Installation
+```bash
+git clone https://github.com/EnigmaLabsAI/multiverse
+cd multiverse
+pip install -r requirements.txt
+```
+
+### Running the model
+
+```bash
+python src/play.py --compile
+```
+
+> Note: on Apple Silicon you must enable CPU fallback for the MPS backend: `PYTORCH_ENABLE_MPS_FALLBACK=1 python src/play.py`
+
+When running this command, you will be prompted with the controls. Press `enter` to start.
+
+
+Then the game will start:
+* To control the silver car at the top of the screen, use the arrow keys.
+* To control the blue car at the bottom, use the WASD keys.
+
+
+
+---
+
+
+## Training
+
+Multiverse comprises two models:
+* Denoiser - a world model that simulates the game
+* Upsampler - a model that takes frames from the denoiser and increases their resolution
+
+### Denoiser training
+
+#### 1. Download the dataset
+Download the Denoiser's training set from [🤗 Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res).
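+
+If you prefer to script the download, here is a minimal sketch using `huggingface_hub` (the target directory is an arbitrary choice, not a repo convention):
+
+```python
+# Sketch: fetch the low-res dataset repo with huggingface_hub.
+from huggingface_hub import snapshot_download
+
+snapshot_download(
+    repo_id="Enigma-AI/multiplayer-racing-low-res",
+    repo_type="dataset",
+    local_dir="data/low_res_raw",  # arbitrary local path
+)
+```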
+
+#### 2. Process data for training
+Run the command:
+```bash
+python src/process_denoiser_files.py
+```
+
+#### 3. Edit training configuration
+
+Edit [config/env/racing.yaml](config/env/racing.yaml) and set:
+- `path_data_low_res` to `/low_res`
+- `path_data_full_res` to `/full_res`
+
+Edit [config/trainer.yaml](config/trainer.yaml) to train the `denoiser`:
+```yaml
+train_model: denoiser
+```
+
+#### 4. Launch training run
+
+You can then launch a training run with `python src/main.py`.
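+
+Since the configs are composed with Hydra (see `config/trainer.yaml`), you can sanity-check the merged configuration before launching a long run. A minimal sketch, assuming it is run from the `src/` directory with Hydra ≥ 1.2:
+
+```python
+# Sketch: print the composed training config without starting a run.
+from hydra import compose, initialize
+
+with initialize(version_base=None, config_path="../config"):
+    cfg = compose(config_name="trainer", overrides=["train_model=denoiser"])
+    print(cfg.train_model, cfg.env.train.id)
+```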
+
+
+### Upsampler training
+
+#### 1. Download the dataset
+Download the Upsampler's training set from [🤗 Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-full-res).
+
+#### 2. Process data for training
+Run the command:
+```bash
+python src/process_upsampler_files.py
+```
+
+#### 3. Edit training configuration
+
+Edit [config/env/racing.yaml](config/env/racing.yaml) and set:
+- `path_data_low_res` to `/low_res`
+- `path_data_full_res` to `/full_res`
+
+Edit [config/trainer.yaml](config/trainer.yaml) to train the `upsampler`:
+```yaml
+train_model: upsampler
+```
+
+#### 4. Launch training run
+
+You can then launch a training run with `python src/main.py`.
+
+
+---
+
+## Datasets
+
+1. We've collected over 4 hours of multiplayer (1v1) footage from Gran Turismo 4 at a resolution of 48x64 (per player): [🤗 Huggingface link](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res).
+
+2. A sparse sampling of full-resolution, cropped frames is available for training the upsampler at a resolution of 350x530: [🤗 Huggingface link](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-full-res).
+
+The datasets contain a variety of situations: acceleration, braking, overtakes, crashes, and expert driving for both players.
+You can read about the data collection mechanism [here](https://enigma-labs.io/blog).
+
+Note: The full-resolution dataset is intended only for upsampler training and is not suitable for world model training.
+
+---
+
+## Outside resources
+
+- DIAMOND - https://github.com/eloialonso/diamond
+- AI-MarioKart64 - https://github.com/Dere-Wah/AI-MarioKart64
+
+---
+
+## Cloud Gaming Server
+
+This project now includes a WebSocket-based cloud gaming server that allows you to play the game through a web browser.
+
+### Using Docker (Recommended for GPU Servers)
+
+The easiest way to deploy the cloud gaming server on a machine with an NVIDIA GPU is using Docker:
+
+```bash
+# Build the Docker image
+docker build -t ai-game-multiverse .
+
+# Run the container with GPU support
+docker run --gpus all -p 8080:8080 ai-game-multiverse
+```
+
+Then access the web interface at http://yourserver:8080
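+
+If the stream looks like placeholder frames, first confirm the container can actually see the GPU; a quick check, run from inside the container:
+
+```python
+# Sketch: verify CUDA is visible to PyTorch inside the container.
+import torch
+
+print(torch.cuda.is_available(), torch.cuda.device_count())
+```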
+
+### Features
+
+- Web-based interface accessible from any modern browser
+- Real-time streaming of AI-generated game frames
+- Keyboard and mouse controls
+- Multiple scene selection
+- WebSocket communication for low-latency interaction
+
+### Usage
+
+1. Access the web interface at http://yourserver:8080
+2. Click "Connect" to establish a WebSocket connection
+3. Select a scene from the dropdown
+4. Click "Start Stream" to begin streaming frames
+5. Use WASD keys for movement, Space for jump, Shift for attack
+6. Mouse controls camera view (click on game area to capture mouse)
+
+Note: The server requires an NVIDIA GPU for optimal performance with the AI models. Without a suitable GPU, it will fall back to using simple placeholder frames.
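+
+You can also drive the same protocol from a script. A minimal sketch using `aiohttp` (the message shapes mirror `example/server.py`; the host is a placeholder):
+
+```python
+# Sketch: connect to /ws, start a stream, press "forward", and read frames.
+import asyncio
+import base64
+
+import aiohttp
+
+async def main():
+    async with aiohttp.ClientSession() as session:
+        async with session.ws_connect("http://yourserver:8080/ws") as ws:
+            await ws.send_json({"action": "start_stream", "requestId": "r1", "fps": 16})
+            await ws.send_json({"action": "keyboard_input", "requestId": "r2",
+                                "key": "forward", "pressed": True})
+            for _ in range(16):  # roughly one second of messages at 16 FPS
+                msg = await ws.receive_json()
+                if msg.get("action") == "frame":
+                    print(f"frame: {len(base64.b64decode(msg['frameData']))} bytes")
+
+asyncio.run(main())
+```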
\ No newline at end of file
diff --git a/config/agent/racing.yaml b/config/agent/racing.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..470a4f2768357c5591005fb8e7b26ac2f8b0a14a
--- /dev/null
+++ b/config/agent/racing.yaml
@@ -0,0 +1,66 @@
+_target_: agent.AgentConfig
+
+denoiser:
+ _target_: models.diffusion.DenoiserConfig
+ sigma_data: 0.5
+ sigma_offset_noise: 0.1
+ noise_previous_obs: true
+ upsampling_factor: null
+ frame_sampling:
+ - count: 4
+ stride: 1
+ - count: 4
+ stride: 4
+ inner_model:
+ _target_: models.diffusion.InnerModelConfig
+ img_channels: 6
+ num_steps_conditioning: 8
+ cond_channels: 2048
+ depths:
+ - 2
+ - 2
+ - 2
+ - 2
+ channels:
+ - 128
+ - 256
+ - 512
+ - 1024
+ attn_depths:
+ - 0
+ - 0
+ - 1
+ - 1
+
+upsampler:
+ _target_: models.diffusion.DenoiserConfig
+ sigma_data: 0.5
+ sigma_offset_noise: 0.1
+ noise_previous_obs: false
+ upsampling_factor: 10
+ upsampling_frame_height: 350
+ upsampling_frame_width: 530
+ inner_model:
+ _target_: models.diffusion.InnerModelConfig
+ img_channels: 6
+ num_steps_conditioning: 0
+ cond_channels: 2048
+ depths:
+ - 2
+ - 2
+ - 2
+ - 2
+ channels:
+ - 64
+ - 64
+ - 128
+ - 256
+ attn_depths:
+ - 0
+ - 0
+ - 0
+ - 0
+
+rew_end_model: null
+
+actor_critic: null
diff --git a/config/env/racing.yaml b/config/env/racing.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..79c035cfe6d02d3dec551bea230fcdcf5f75f842
--- /dev/null
+++ b/config/env/racing.yaml
@@ -0,0 +1,7 @@
+train:
+ id: racing
+ size: [700, 530]
+num_actions: 66
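+# Set these to the processed dataset directories before training (see README):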
+path_data_low_res: null
+path_data_full_res: null
+keymap: racing
diff --git a/config/trainer.yaml b/config/trainer.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4217362d937e42b6144e4a624bdd2c5e3a5a125f
--- /dev/null
+++ b/config/trainer.yaml
@@ -0,0 +1,113 @@
+defaults:
+ - _self_
+ - env: racing
+ - agent: racing
+ - world_model_env: fast
+
+hydra:
+ job:
+ chdir: True
+
+wandb:
+ mode: offline
+ project: null
+ entity: null
+ name: null
+ group: null
+ tags: null
+
+initialization:
+ path_to_ckpt: null
+ load_denoiser: True
+ load_rew_end_model: True
+ load_actor_critic: True
+
+common:
+ devices: all # int, list of int, cpu, or all
+ seed: null
+ resume: False # do not modify, set by scripts/resume.sh only.
+
+checkpointing:
+ save_agent_every: 5
+ num_to_keep: 11 # number of checkpoints to keep, use null to disable
+
+collection:
+ train:
+ num_envs: 1
+ epsilon: 0.01
+ num_steps_total: 100000
+ first_epoch:
+ min: 5000
+ max: 10000 # null: no maximum
+ threshold_rew: 10
+ steps_per_epoch: 100
+ test:
+ num_envs: 1
+ num_episodes: 4
+ epsilon: 0.0
+ num_final_episodes: 100
+
+static_dataset:
+ path: ${env.path_data_low_res}
+ ignore_sample_weights: True
+
+training:
+ should: True
+ num_final_epochs: 600
+ cache_in_ram: False
+ num_workers_data_loaders: 1
+ model_free: False # if True, turn off world_model training and RL in imagination
+ compile_wm: False
+
+evaluation:
+ should: True
+ every: 20
+
+train_model: denoiser
+
+denoiser:
+ training:
+ num_autoregressive_steps: 8
+ initial_num_consecutive_page_count: 1
+ num_consecutive_pages:
+ - epoch: 400
+ count: 10
+ - epoch: 500
+ count: 50
+ start_after_epochs: 0
+ steps_first_epoch: 10
+ steps_per_epoch: 20
+ sample_weights: null
+ batch_size: 30
+ grad_acc_steps: 2
+ lr_warmup_steps: 100
+ max_grad_norm: 10.0
+
+ optimizer:
+ lr: 1e-4
+ weight_decay: 1e-2
+ eps: 1e-8
+
+ sigma_distribution: # log normal distribution for sigma during training
+ _target_: models.diffusion.SigmaDistributionConfig
+ loc: -1.2
+ scale: 1.2
+ sigma_min: 2e-3
+ sigma_max: 20
+
+upsampler:
+ training:
+ num_autoregressive_steps: 1
+ initial_num_consecutive_page_count: 1
+ start_after_epochs: 0
+ steps_first_epoch: 20
+ steps_per_epoch: 20
+ sample_weights: null
+ batch_size: 4
+ grad_acc_steps: 2
+ lr_warmup_steps: 100
+ max_grad_norm: 10.0
+
+ optimizer: ${denoiser.optimizer}
+ sigma_distribution: ${denoiser.sigma_distribution}
+
diff --git a/config/world_model_env/fast.yaml b/config/world_model_env/fast.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..08eadc86c1cbe70f1ab822844f1194c9319d8120
--- /dev/null
+++ b/config/world_model_env/fast.yaml
@@ -0,0 +1,27 @@
+_target_: envs.WorldModelEnvConfig
+horizon: 1000
+num_batches_to_preload: 256
+diffusion_sampler_next_obs:
+ _target_: models.diffusion.DiffusionSamplerConfig
+ num_steps_denoising: 1
+ sigma_min: 2e-3
+ sigma_max: 5.0
+ rho: 7
+ order: 1 # 1: Euler, 2: Heun
+ s_churn: 0.0 # Amount of stochasticity
+ s_tmin: 0.0
+ s_tmax: ${eval:'float("inf")'}
+ s_noise: 1.0
+ s_cond: 0.005
+diffusion_sampler_upsampling:
+ _target_: models.diffusion.DiffusionSamplerConfig
+ num_steps_denoising: 1
+ sigma_min: 1
+ sigma_max: 5.0
+ rho: 7
+ order: 2 # 1: Euler, 2: Heun
+ s_churn: 10.0 # Amount of stochasticity
+ s_tmin: 1
+ s_tmax: 5
+ s_noise: 0.9
+ s_cond: 0
\ No newline at end of file
diff --git a/example/Dockerfile b/example/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..28542909b9f5203a3d24ebe0234abc82965b4486
--- /dev/null
+++ b/example/Dockerfile
@@ -0,0 +1,59 @@
+FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+ENV PYTHONUNBUFFERED=1
+
+RUN apt-get update && apt-get install --no-install-recommends -y \
+ build-essential \
+ python3.11 \
+ python3-pip \
+ python3-dev \
+ git \
+ curl \
+ ffmpeg \
+ libglib2.0-0 \
+ libsm6 \
+ libxrender1 \
+ libxext6 \
+ ninja-build \
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /code
+
+COPY ./requirements.txt /code/requirements.txt
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+# Switch to the "user" user
+USER user
+# Set home to the user's home directory
+ENV HOME=/home/user \
+ PATH=/home/user/.local/bin:$PATH
+
+# Set Python path and environment variables
+ENV PYTHONPATH=$HOME/app \
+ PYTHONUNBUFFERED=1 \
+ DATA_ROOT=/tmp/data
+
+RUN echo "Installing requirements.txt"
+RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
+
+# Install NVIDIA Apex with CUDA and C++ extensions
+RUN cd $HOME && \
+ git clone https://github.com/NVIDIA/apex && \
+ cd apex && \
+ NVCC_APPEND_FLAGS="--threads 4" pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--parallel" --global-option="8" ./
+
+WORKDIR $HOME/app
+
+# Copy all files and set proper ownership
+COPY --chown=user . $HOME/app
+
+# Expose the port that server.py uses (8080)
+EXPOSE 8080
+
+ENV PORT=8080
+
+# Run the HF space launcher script which sets up the correct paths
+CMD ["python3", "run_hf_space.py"]
\ No newline at end of file
diff --git a/example/client.js b/example/client.js
new file mode 100644
index 0000000000000000000000000000000000000000..36724958f7e64964cc73355a3d7908abb7a2abe0
--- /dev/null
+++ b/example/client.js
@@ -0,0 +1,603 @@
+// MatrixGame WebSocket Client
+
+// WebSocket connection
+let socket = null;
+let userId = null;
+let isStreaming = false;
+let lastFrameTime = 0;
+let frameCount = 0;
+let fpsUpdateInterval = null;
+
+// DOM Elements
+const connectBtn = document.getElementById('connect-btn');
+const startStreamBtn = document.getElementById('start-stream-btn');
+const stopStreamBtn = document.getElementById('stop-stream-btn');
+const sceneSelect = document.getElementById('scene-select');
+const gameCanvas = document.getElementById('game-canvas');
+const connectionLog = document.getElementById('connection-log');
+const mousePosition = document.getElementById('mouse-position');
+const fpsCounter = document.getElementById('fps-counter');
+const mouseTrackingArea = document.getElementById('mouse-tracking-area');
+
+// Pointer Lock API support check
+const pointerLockSupported = 'pointerLockElement' in document ||
+ 'mozPointerLockElement' in document ||
+ 'webkitPointerLockElement' in document;
+
+// Keyboard DOM elements
+const keyElements = {
+ 'w': document.getElementById('key-w'),
+ 'a': document.getElementById('key-a'),
+ 's': document.getElementById('key-s'),
+ 'd': document.getElementById('key-d'),
+ 'space': document.getElementById('key-space'),
+ 'shift': document.getElementById('key-shift')
+};
+
+// Key mapping to action names
+const keyToAction = {
+ 'w': 'forward',
+ 'arrowup': 'forward',
+ 'a': 'left',
+ 'arrowleft': 'left',
+ 's': 'back',
+ 'arrowdown': 'back',
+ 'd': 'right',
+ 'arrowright': 'right',
+ ' ': 'jump',
+ 'shift': 'attack'
+};
+
+// Key state tracking
+const keyState = {
+ 'forward': false,
+ 'back': false,
+ 'left': false,
+ 'right': false,
+ 'jump': false,
+ 'attack': false
+};
+
+// Mouse state
+const mouseState = {
+ x: 0,
+ y: 0,
+ captured: false
+};
+
+// Test server connectivity before establishing WebSocket
+async function testServerConnectivity() {
+ try {
+ // Get base path by extracting path from the script tag's src attribute
+ let basePath = '';
+ const scriptTags = document.getElementsByTagName('script');
+ for (const script of scriptTags) {
+ if (script.src.includes('client.js')) {
+ const url = new URL(script.src);
+ basePath = url.pathname.replace('/assets/client.js', '');
+ break;
+ }
+ }
+
+ // Try to fetch the debug endpoint to see if the server is accessible
+ const response = await fetch(`${window.location.protocol}//${window.location.host}${basePath}/api/debug`);
+ if (!response.ok) {
+ throw new Error(`Server returned ${response.status}`);
+ }
+
+ const debugInfo = await response.json();
+ logMessage(`Server connection test successful! Server time: ${new Date(debugInfo.server_time * 1000).toLocaleTimeString()}`);
+
+ // Log available routes from server
+ if (debugInfo.all_routes && debugInfo.all_routes.length > 0) {
+ logMessage(`Available routes: ${debugInfo.all_routes.join(', ')}`);
+ }
+
+ // Return the debug info for connection setup
+ return debugInfo;
+ } catch (error) {
+ logMessage(`Server connection test failed: ${error.message}`);
+ return null;
+ }
+}
+
+// Connect to WebSocket server
+async function connectWebSocket() {
+ // First test connectivity to the server
+ logMessage('Testing server connectivity...');
+ const debugInfo = await testServerConnectivity();
+
+ // Use secure WebSocket (wss://) if the page is loaded over HTTPS
+ const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
+
+ // Get base path by extracting path from the script tag's src attribute
+ let basePath = '';
+ if (debugInfo && debugInfo.base_path) {
+ // Use base path from server if available
+ basePath = debugInfo.base_path;
+ logMessage(`Using server-provided base path: ${basePath}`);
+ } else {
+ const scriptTags = document.getElementsByTagName('script');
+ for (const script of scriptTags) {
+ if (script.src.includes('client.js')) {
+ const url = new URL(script.src);
+ basePath = url.pathname.replace('/assets/client.js', '');
+ break;
+ }
+ }
+ }
+
+ // Try both with and without base path for WebSocket connection
+ let serverUrl = `${protocol}//${window.location.hostname}${window.location.port ? ':' + window.location.port : ''}${basePath}/ws`;
+ logMessage(`Attempting to connect to WebSocket at ${serverUrl}...`);
+
+ // For Hugging Face Spaces, try the direct /ws path if the base path doesn't work
+ const fallbackUrl = `${protocol}//${window.location.hostname}${window.location.port ? ':' + window.location.port : ''}/ws`;
+
+ try {
+ socket = new WebSocket(serverUrl);
+ setupWebSocketHandlers();
+
+ // Set a timeout to try the fallback URL if the first one doesn't connect
+ setTimeout(() => {
+ if (socket.readyState !== WebSocket.OPEN && socket.readyState !== WebSocket.CONNECTING) {
+ logMessage(`Connection to ${serverUrl} failed. Trying fallback URL: ${fallbackUrl}`);
+ socket = new WebSocket(fallbackUrl);
+ setupWebSocketHandlers();
+ }
+ }, 3000);
+ } catch (error) {
+ logMessage(`Error connecting to WebSocket: ${error.message}`);
+ resetUI();
+ }
+}
+
+// Set up WebSocket event handlers
+function setupWebSocketHandlers() {
+ socket.onopen = () => {
+ logMessage('WebSocket connection established');
+ connectBtn.textContent = 'Disconnect';
+ startStreamBtn.disabled = false;
+ sceneSelect.disabled = false;
+ };
+
+ socket.onmessage = (event) => {
+ const message = JSON.parse(event.data);
+
+ switch (message.action) {
+ case 'welcome':
+ userId = message.userId;
+ logMessage(`Connected with user ID: ${userId}`);
+
+ // Update scene options if server provides them
+ if (message.scenes && Array.isArray(message.scenes)) {
+ sceneSelect.innerHTML = '';
+ message.scenes.forEach(scene => {
+ const option = document.createElement('option');
+ option.value = scene;
+ option.textContent = scene.charAt(0).toUpperCase() + scene.slice(1);
+ sceneSelect.appendChild(option);
+ });
+ }
+ break;
+
+ case 'frame':
+ // Process incoming frame
+ processFrame(message);
+ break;
+
+ case 'start_stream':
+ if (message.success) {
+ isStreaming = true;
+ startStreamBtn.disabled = true;
+ stopStreamBtn.disabled = false;
+ logMessage(`Streaming started: ${message.message}`);
+
+ // Start FPS counter
+ startFpsCounter();
+ } else {
+ logMessage(`Error starting stream: ${message.error}`);
+ }
+ break;
+
+ case 'stop_stream':
+ if (message.success) {
+ isStreaming = false;
+ startStreamBtn.disabled = false;
+ stopStreamBtn.disabled = true;
+ logMessage('Streaming stopped');
+
+ // Stop FPS counter
+ stopFpsCounter();
+ } else {
+ logMessage(`Error stopping stream: ${message.error}`);
+ }
+ break;
+
+ case 'pong':
+ // Server responded to ping
+ break;
+
+ case 'change_scene':
+ if (message.success) {
+ logMessage(`Scene changed to ${message.scene}`);
+ } else {
+ logMessage(`Error changing scene: ${message.error}`);
+ }
+ break;
+
+ default:
+ logMessage(`Received message: ${JSON.stringify(message)}`);
+ }
+ };
+
+ socket.onclose = (event) => {
+ logMessage(`WebSocket connection closed (code: ${event.code}, reason: ${event.reason || 'none given'})`);
+ resetUI();
+ };
+
+ socket.onerror = (error) => {
+ logMessage(`WebSocket error. This is often caused by CORS issues or the server being inaccessible.`);
+ console.error('WebSocket error:', error);
+ resetUI();
+ };
+}
+
+// Disconnect from WebSocket server
+function disconnectWebSocket() {
+ if (socket && socket.readyState === WebSocket.OPEN) {
+ // Stop streaming if active
+ if (isStreaming) {
+ sendStopStream();
+ }
+
+ // Close the socket
+ socket.close();
+ logMessage('Disconnected from server');
+ }
+}
+
+// Start streaming frames
+function sendStartStream() {
+ if (socket && socket.readyState === WebSocket.OPEN) {
+ socket.send(JSON.stringify({
+ action: 'start_stream',
+ requestId: generateRequestId(),
+ fps: 16 // Default FPS
+ }));
+ }
+}
+
+// Stop streaming frames
+function sendStopStream() {
+ if (socket && socket.readyState === WebSocket.OPEN) {
+ socket.send(JSON.stringify({
+ action: 'stop_stream',
+ requestId: generateRequestId()
+ }));
+ }
+}
+
+// Send keyboard input to server
+function sendKeyboardInput(key, pressed) {
+ if (socket && socket.readyState === WebSocket.OPEN) {
+ socket.send(JSON.stringify({
+ action: 'keyboard_input',
+ requestId: generateRequestId(),
+ key: key,
+ pressed: pressed
+ }));
+ }
+}
+
+// Send mouse input to server
+function sendMouseInput(x, y) {
+ if (socket && socket.readyState === WebSocket.OPEN && isStreaming) {
+ socket.send(JSON.stringify({
+ action: 'mouse_input',
+ requestId: generateRequestId(),
+ x: x,
+ y: y
+ }));
+ }
+}
+
+// Change scene
+function sendChangeScene(scene) {
+ if (socket && socket.readyState === WebSocket.OPEN) {
+ socket.send(JSON.stringify({
+ action: 'change_scene',
+ requestId: generateRequestId(),
+ scene: scene
+ }));
+ }
+}
+
+// Process incoming frame
+function processFrame(message) {
+ // Update FPS calculation
+ const now = performance.now();
+ if (lastFrameTime > 0) {
+ frameCount++;
+ }
+ lastFrameTime = now;
+
+ // Update the canvas with the new frame
+ if (message.frameData) {
+ gameCanvas.src = `data:image/jpeg;base64,${message.frameData}`;
+ }
+}
+
+// Generate a random request ID
+function generateRequestId() {
+ return Math.random().toString(36).substring(2, 15);
+}
+
+// Log message to the connection info panel
+function logMessage(message) {
+ const logEntry = document.createElement('div');
+ logEntry.className = 'log-entry';
+
+ const timestamp = new Date().toLocaleTimeString();
+ logEntry.textContent = `[${timestamp}] ${message}`;
+
+ connectionLog.appendChild(logEntry);
+ connectionLog.scrollTop = connectionLog.scrollHeight;
+
+ // Limit number of log entries
+ while (connectionLog.children.length > 100) {
+ connectionLog.removeChild(connectionLog.firstChild);
+ }
+}
+
+// Start FPS counter updates
+function startFpsCounter() {
+ frameCount = 0;
+ lastFrameTime = 0;
+
+ // Update FPS display every second
+ fpsUpdateInterval = setInterval(() => {
+ fpsCounter.textContent = `FPS: ${frameCount}`;
+ frameCount = 0;
+ }, 1000);
+}
+
+// Stop FPS counter updates
+function stopFpsCounter() {
+ if (fpsUpdateInterval) {
+ clearInterval(fpsUpdateInterval);
+ fpsUpdateInterval = null;
+ }
+ fpsCounter.textContent = 'FPS: 0';
+}
+
+// Reset UI to initial state
+function resetUI() {
+ connectBtn.textContent = 'Connect';
+ startStreamBtn.disabled = true;
+ stopStreamBtn.disabled = true;
+ sceneSelect.disabled = true;
+
+ // Reset key indicators
+ for (const key in keyElements) {
+ keyElements[key].classList.remove('active');
+ }
+
+ // Stop FPS counter
+ stopFpsCounter();
+
+ // Reset streaming state
+ isStreaming = false;
+}
+
+// Event Listeners
+connectBtn.addEventListener('click', () => {
+ if (socket && socket.readyState === WebSocket.OPEN) {
+ disconnectWebSocket();
+ } else {
+ connectWebSocket();
+ }
+});
+
+startStreamBtn.addEventListener('click', sendStartStream);
+stopStreamBtn.addEventListener('click', sendStopStream);
+
+sceneSelect.addEventListener('change', () => {
+ sendChangeScene(sceneSelect.value);
+});
+
+// Keyboard event listeners
+document.addEventListener('keydown', (event) => {
+ const key = event.key.toLowerCase();
+
+ // Map key to action
+ let action = keyToAction[key];
+ if (!action && key === ' ') {
+ action = keyToAction[' ']; // Handle spacebar
+ }
+
+ if (action && !keyState[action]) {
+ keyState[action] = true;
+
+ // Update visual indicator
+ const keyElement = keyElements[key] ||
+ (key === ' ' ? keyElements['space'] : null) ||
+ (key === 'shift' ? keyElements['shift'] : null);
+
+ if (keyElement) {
+ keyElement.classList.add('active');
+ }
+
+ // Send to server
+ sendKeyboardInput(action, true);
+ }
+
+ // Prevent default actions for game controls
+ if (Object.keys(keyToAction).includes(key) || key === ' ') {
+ event.preventDefault();
+ }
+});
+
+document.addEventListener('keyup', (event) => {
+ const key = event.key.toLowerCase();
+
+ // Map key to action
+ let action = keyToAction[key];
+ if (!action && key === ' ') {
+ action = keyToAction[' ']; // Handle spacebar
+ }
+
+ if (action && keyState[action]) {
+ keyState[action] = false;
+
+ // Update visual indicator
+ const keyElement = keyElements[key] ||
+ (key === ' ' ? keyElements['space'] : null) ||
+ (key === 'shift' ? keyElements['shift'] : null);
+
+ if (keyElement) {
+ keyElement.classList.remove('active');
+ }
+
+ // Send to server
+ sendKeyboardInput(action, false);
+ }
+});
+
+// Mouse capture functions
+function requestPointerLock() {
+ if (!mouseState.captured && pointerLockSupported) {
+ mouseTrackingArea.requestPointerLock = mouseTrackingArea.requestPointerLock ||
+ mouseTrackingArea.mozRequestPointerLock ||
+ mouseTrackingArea.webkitRequestPointerLock;
+ mouseTrackingArea.requestPointerLock();
+ logMessage('Mouse captured. Press ESC to release.');
+ }
+}
+
+function exitPointerLock() {
+ if (mouseState.captured) {
+ document.exitPointerLock = document.exitPointerLock ||
+ document.mozExitPointerLock ||
+ document.webkitExitPointerLock;
+ document.exitPointerLock();
+ logMessage('Mouse released.');
+ }
+}
+
+// Handle pointer lock change events
+document.addEventListener('pointerlockchange', pointerLockChangeHandler);
+document.addEventListener('mozpointerlockchange', pointerLockChangeHandler);
+document.addEventListener('webkitpointerlockchange', pointerLockChangeHandler);
+
+function pointerLockChangeHandler() {
+ if (document.pointerLockElement === mouseTrackingArea ||
+ document.mozPointerLockElement === mouseTrackingArea ||
+ document.webkitPointerLockElement === mouseTrackingArea) {
+ // Pointer is locked, enable mouse movement tracking
+ mouseState.captured = true;
+ document.addEventListener('mousemove', handleMouseMovement);
+ } else {
+ // Pointer is unlocked, disable mouse movement tracking
+ mouseState.captured = false;
+ document.removeEventListener('mousemove', handleMouseMovement);
+ // Reset mouse state
+ mouseState.x = 0;
+ mouseState.y = 0;
+ mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
+ throttledSendMouseInput();
+ }
+}
+
+// Mouse tracking with pointer lock
+function handleMouseMovement(event) {
+ if (mouseState.captured) {
+ // Use movement for mouse look when captured
+ const sensitivity = 0.005; // Adjust sensitivity
+ mouseState.x += event.movementX * sensitivity;
+ mouseState.y -= event.movementY * sensitivity; // Invert Y for intuitive camera control
+
+ // Clamp values
+ mouseState.x = Math.max(-1, Math.min(1, mouseState.x));
+ mouseState.y = Math.max(-1, Math.min(1, mouseState.y));
+
+ // Update display
+ mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
+
+ // Send to server (throttled)
+ throttledSendMouseInput();
+ }
+}
+
+// Mouse click to capture
+mouseTrackingArea.addEventListener('click', () => {
+ if (!mouseState.captured && isStreaming) {
+ requestPointerLock();
+ }
+});
+
+// Standard mouse tracking for when pointer is not locked
+mouseTrackingArea.addEventListener('mousemove', (event) => {
+ if (!mouseState.captured) {
+ // Calculate normalized coordinates relative to the center of the tracking area
+ const rect = mouseTrackingArea.getBoundingClientRect();
+ const centerX = rect.width / 2;
+ const centerY = rect.height / 2;
+
+ // Calculate relative position from center (-1 to 1)
+ const relX = (event.clientX - rect.left - centerX) / centerX;
+ const relY = (event.clientY - rect.top - centerY) / centerY;
+
+ // Scale down for smoother movement (similar to conditions.py)
+ const scaleFactor = 0.05;
+ mouseState.x = relX * scaleFactor;
+ mouseState.y = -relY * scaleFactor; // Invert Y for intuitive camera control
+
+ // Update display
+ mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
+
+ // Send to server (throttled)
+ throttledSendMouseInput();
+ }
+});
+
+// Throttle mouse movement to avoid flooding the server
+const throttledSendMouseInput = (() => {
+ let lastSentTime = 0;
+ const interval = 50; // milliseconds
+
+ return () => {
+ const now = performance.now();
+ if (now - lastSentTime >= interval) {
+ sendMouseInput(mouseState.x, mouseState.y);
+ lastSentTime = now;
+ }
+ };
+})();
+
+// Toggle panel collapse/expand
+function togglePanel(panelId) {
+ const panel = document.getElementById(panelId);
+ const button = panel.querySelector('.toggle-button');
+
+ if (panel.classList.contains('collapsed')) {
+ // Expand the panel
+ panel.classList.remove('collapsed');
+ button.textContent = '−'; // Minus sign
+ } else {
+ // Collapse the panel
+ panel.classList.add('collapsed');
+ button.textContent = '+'; // Plus sign
+ }
+}
+
+// Initialize the UI
+resetUI();
+
+// Make panel headers clickable
+document.querySelectorAll('.panel-header').forEach(header => {
+ header.addEventListener('click', () => {
+ const panelId = header.parentElement.id;
+ togglePanel(panelId);
+ });
+});
\ No newline at end of file
diff --git a/example/engine.py b/example/engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d5f724a95a78a72848435427171ca45a48a6008
--- /dev/null
+++ b/example/engine.py
@@ -0,0 +1,332 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+MatrixGame Engine
+
+This module handles the core rendering and model inference for the MatrixGame project.
+"""
+
+import os
+import logging
+import argparse
+import time
+import torch
+import numpy as np
+from PIL import Image
+import cv2
+from einops import rearrange
+from diffusers.utils import load_image
+from diffusers.video_processor import VideoProcessor
+from typing import Dict, List, Tuple, Any, Optional, Union
+
+# MatrixGame specific imports
+from matrixgame.sample.pipeline_matrixgame import MatrixGameVideoPipeline
+from matrixgame.model_variants import get_dit
+from matrixgame.vae_variants import get_vae
+from matrixgame.encoder_variants import get_text_enc
+from matrixgame.model_variants.matrixgame_dit_src import MGVideoDiffusionTransformerI2V
+from matrixgame.sample.flow_matching_scheduler_matrixgame import FlowMatchDiscreteScheduler
+from teacache_forward import teacache_forward
+
+# Import utility functions
+from utils import (
+ visualize_controls,
+ frame_to_jpeg,
+ load_scene_frames,
+ logger
+)
+
+class MatrixGameEngine:
+ """
+ Core engine for MatrixGame model inference and frame generation.
+ """
+ def __init__(self, args: Optional[argparse.Namespace] = None):
+ """
+ Initialize the MatrixGame engine with configuration parameters.
+
+ Args:
+ args: Optional parsed command line arguments for model configuration
+ """
+ # Set default parameters if args not provided
+ self.frame_width = getattr(args, 'frame_width', 640)
+ self.frame_height = getattr(args, 'frame_height', 360)
+ self.fps = getattr(args, 'fps', 16)
+ self.inference_steps = getattr(args, 'inference_steps', 20)
+ self.guidance_scale = getattr(args, 'guidance_scale', 6.0)
+ self.num_pre_frames = getattr(args, 'num_pre_frames', 3)
+
+ # Initialize state
+ self.frame_count = 0
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.weight_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+ # Model paths from environment or args
+ self.vae_path = os.environ.get("VAE_PATH", "./models/matrixgame/vae/")
+ self.dit_path = os.environ.get("DIT_PATH", "./models/matrixgame/dit/")
+ self.textenc_path = os.environ.get("TEXTENC_PATH", "./models/matrixgame")
+
+ # Cache scene initial frames
+ self.scenes = {
+ 'forest': load_scene_frames('forest', self.frame_width, self.frame_height),
+ 'desert': load_scene_frames('desert', self.frame_width, self.frame_height),
+ 'beach': load_scene_frames('beach', self.frame_width, self.frame_height),
+ 'hills': load_scene_frames('hills', self.frame_width, self.frame_height),
+ 'river': load_scene_frames('river', self.frame_width, self.frame_height),
+ 'icy': load_scene_frames('icy', self.frame_width, self.frame_height),
+ 'mushroom': load_scene_frames('mushroom', self.frame_width, self.frame_height),
+ 'plain': load_scene_frames('plain', self.frame_width, self.frame_height)
+ }
+
+ # Cache initial images for model input
+ self.scene_initial_images = {}
+
+ # Initialize MatrixGame pipeline
+ self.model_loaded = False
+ if torch.cuda.is_available():
+ try:
+ self._init_models()
+ self.model_loaded = True
+ logger.info("MatrixGame models loaded successfully")
+ except Exception as e:
+ logger.error(f"Failed to initialize MatrixGame models: {str(e)}")
+ logger.info("Falling back to frame cycling mode")
+ else:
+ logger.warning("CUDA not available. Using frame cycling mode only.")
+
+ def _init_models(self):
+ """Initialize MatrixGame models (VAE, text encoder, transformer)"""
+ # Initialize flow matching scheduler
+ self.scheduler = FlowMatchDiscreteScheduler(
+ shift=15.0,
+ reverse=True,
+ solver="euler"
+ )
+
+ # Initialize VAE
+ try:
+ self.vae = get_vae("matrixgame", self.vae_path, self.weight_dtype)
+ self.vae.requires_grad_(False)
+ self.vae.eval()
+ self.vae.enable_tiling()
+ logger.info("VAE model loaded successfully")
+ except Exception as e:
+ logger.error(f"Error loading VAE model: {str(e)}")
+ raise
+
+ # Initialize DIT (Transformer)
+ try:
+ dit = MGVideoDiffusionTransformerI2V.from_pretrained(self.dit_path)
+ dit.requires_grad_(False)
+ dit.eval()
+ logger.info("DIT model loaded successfully")
+ except Exception as e:
+ logger.error(f"Error loading DIT model: {str(e)}")
+ raise
+
+ # Initialize text encoder
+ try:
+ self.text_enc = get_text_enc('matrixgame', self.textenc_path, weight_dtype=self.weight_dtype, i2v_type='refiner')
+ logger.info("Text encoder loaded successfully")
+ except Exception as e:
+ logger.error(f"Error loading text encoder: {str(e)}")
+ raise
+
+ # Initialize pipeline
+ try:
+ self.pipeline = MatrixGameVideoPipeline(
+ vae=self.vae.vae,
+ text_encoder=self.text_enc,
+ transformer=dit,
+ scheduler=self.scheduler,
+ ).to(self.weight_dtype).to(self.device)
+ logger.info("Pipeline initialized successfully")
+ except Exception as e:
+ logger.error(f"Error initializing pipeline: {str(e)}")
+ raise
+
+ # Configure teacache for the transformer
+ self.pipeline.transformer.__class__.enable_teacache = True
+ self.pipeline.transformer.__class__.cnt = 0
+ self.pipeline.transformer.__class__.num_steps = self.inference_steps
+ self.pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
+ self.pipeline.transformer.__class__.rel_l1_thresh = 0.075
+ self.pipeline.transformer.__class__.previous_modulated_input = None
+ self.pipeline.transformer.__class__.previous_residual = None
+ self.pipeline.transformer.__class__.forward = teacache_forward
+
+ # Preprocess initial images for all scenes
+ for scene_name, frames in self.scenes.items():
+ if frames:
+ # Use first frame as initial image
+ self.scene_initial_images[scene_name] = self._preprocess_image(frames[0])
+
+ def _preprocess_image(self, image_array: np.ndarray) -> torch.Tensor:
+ """
+ Preprocess an image for the model.
+
+ Args:
+ image_array: Input image as numpy array
+
+ Returns:
+ torch.Tensor: Preprocessed image tensor
+ """
+ # Convert numpy array to PIL Image if needed
+ if isinstance(image_array, np.ndarray):
+ image = Image.fromarray(image_array)
+ else:
+ image = image_array
+
+ # Preprocess for VAE
+ vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, 'vae') else 8
+ video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor)
+ initial_image = video_processor.preprocess(image, height=self.frame_height, width=self.frame_width)
+
+ # Add past frames for stability (use same frame repeated)
+ past_frames = initial_image.repeat(self.num_pre_frames, 1, 1, 1)
+ initial_image = torch.cat([initial_image, past_frames], dim=0)
+
+ return initial_image
+
+ def generate_frame(self, scene_name: str, keyboard_condition: Optional[List] = None,
+ mouse_condition: Optional[List] = None) -> bytes:
+ """
+ Generate the next frame based on current conditions using MatrixGame model.
+
+ Args:
+ scene_name: Name of the current scene
+ keyboard_condition: Keyboard input state
+ mouse_condition: Mouse input state
+
+ Returns:
+ bytes: JPEG bytes of the frame
+ """
+ # Check if model is loaded
+ if not self.model_loaded or not torch.cuda.is_available():
+ # Fall back to frame cycling for demo mode or if models failed to load
+ return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
+ else:
+ # Use MatrixGame model for frame generation
+ try:
+ # Get initial image for this scene
+ initial_image = self.scene_initial_images.get(scene_name)
+ if initial_image is None:
+ # Use forest as default if we don't have an initial image for this scene
+ initial_image = self.scene_initial_images.get('forest')
+ if initial_image is None:
+ # If we still don't have an initial image, fall back to frame cycling
+ logger.error(f"No initial image available for scene {scene_name}")
+ return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
+
+ # Prepare input tensors (move to device and format correctly)
+ if keyboard_condition is None:
+ keyboard_condition = [[0, 0, 0, 0, 0, 0]]
+ if mouse_condition is None:
+ mouse_condition = [[0, 0]]
+
+ # Convert conditions to tensors
+ keyboard_tensor = torch.tensor(keyboard_condition, dtype=torch.float32)
+ mouse_tensor = torch.tensor(mouse_condition, dtype=torch.float32)
+
+ # Move to device and convert to correct dtype
+ keyboard_tensor = keyboard_tensor.to(self.weight_dtype).to(self.device)
+ mouse_tensor = mouse_tensor.to(self.weight_dtype).to(self.device)
+
+ # Get the first frame from the scene for semantic conditioning
+ scene_frames = self.scenes.get(scene_name, self.scenes['forest'])
+ if not scene_frames:
+ return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
+
+                semantic_image = Image.fromarray(scene_frames[0])
+
+ # Generate a single frame with the model
+ # Use fewer inference steps for interactive frame generation
+ with torch.no_grad():
+ # Generate a short video (we'll just use the first frame)
+ # We're using a short length (3 frames) for real-time performance
+ video = self.pipeline(
+ height=self.frame_height,
+ width=self.frame_width,
+ video_length=3, # Generate a very short video for speed
+ mouse_condition=mouse_tensor,
+ keyboard_condition=keyboard_tensor,
+ initial_image=initial_image,
+ num_inference_steps=self.inference_steps,
+ guidance_scale=self.guidance_scale,
+ embedded_guidance_scale=None,
+ data_type="video",
+ vae_ver='884-16c-hy',
+ enable_tiling=True,
+ generator=torch.Generator(device=self.device).manual_seed(42),
+ i2v_type='refiner',
+ semantic_images=semantic_image
+ ).videos[0]
+
+ # Convert video tensor to numpy array (use first frame)
+ video_frame = video[0].permute(1, 2, 0).cpu().numpy()
+ video_frame = (video_frame * 255).astype(np.uint8)
+ frame = video_frame
+
+ # Increment frame counter
+ self.frame_count += 1
+
+ except Exception as e:
+ logger.error(f"Error generating frame with MatrixGame model: {str(e)}")
+ # Fall back to cycling demo frames if model generation fails
+ return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
+
+ # Add visualization of input controls
+ frame = visualize_controls(
+ frame, keyboard_condition, mouse_condition,
+ self.frame_width, self.frame_height
+ )
+
+ # Convert frame to JPEG
+ return frame_to_jpeg(frame, self.frame_height, self.frame_width)
+
+ def _fallback_frame(self, scene_name: str, keyboard_condition: Optional[List] = None,
+ mouse_condition: Optional[List] = None) -> bytes:
+ """
+ Generate a fallback frame when model generation fails.
+
+ Args:
+ scene_name: Name of the current scene
+ keyboard_condition: Keyboard input state
+ mouse_condition: Mouse input state
+
+ Returns:
+ bytes: JPEG bytes of the frame
+ """
+ scene_frames = self.scenes.get(scene_name, self.scenes['forest'])
+ frame_idx = self.frame_count % len(scene_frames)
+ frame = scene_frames[frame_idx].copy()
+ self.frame_count += 1
+
+ # Add fallback mode indicator
+ cv2.putText(frame, "Fallback mode",
+ (10, self.frame_height - 20),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
+
+ # Add visualization of input controls
+ frame = visualize_controls(
+ frame, keyboard_condition, mouse_condition,
+ self.frame_width, self.frame_height
+ )
+
+ # Convert frame to JPEG
+ return frame_to_jpeg(frame, self.frame_height, self.frame_width)
+
+ def get_valid_scenes(self) -> List[str]:
+ """
+ Get a list of valid scene names.
+
+ Returns:
+ List[str]: List of valid scene names
+ """
+ return list(self.scenes.keys())
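+
+
+if __name__ == "__main__":
+    # Smoke test (sketch): render a single frame from the default scene.
+    # Assumes the scene assets used by load_scene_frames are present; without
+    # a GPU (or if model loading fails) this exercises the fallback path.
+    engine = MatrixGameEngine()
+    jpeg_bytes = engine.generate_frame("forest")
+    print(f"Generated frame: {len(jpeg_bytes)} bytes")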
\ No newline at end of file
diff --git a/example/index.html b/example/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..c215658cef7b396ba2435527ea1705a097180f5d
--- /dev/null
+++ b/example/index.html
@@ -0,0 +1,329 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="utf-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1">
+    <title>MatrixGame Client</title>
+    <style>
+        /* Minimal styling; the .active and .collapsed classes are toggled by client.js */
+        body { font-family: sans-serif; background: #1e1e1e; color: #eee; margin: 20px; }
+        button, select { margin-right: 8px; }
+        #mouse-tracking-area { display: inline-block; position: relative; cursor: crosshair; }
+        #game-canvas { display: block; max-width: 100%; background: #000; }
+        .status-bar span { margin-right: 16px; }
+        .panel { border: 1px solid #444; margin-top: 12px; }
+        .panel-header { padding: 6px 10px; background: #333; cursor: pointer; }
+        .panel.collapsed .panel-content { display: none; }
+        .panel-content { padding: 10px; }
+        .key-indicator { display: inline-block; min-width: 2em; padding: 6px 10px; margin: 2px;
+                         border: 1px solid #666; text-align: center; }
+        .key-indicator.active { background: #4caf50; color: #000; }
+        #connection-log { max-height: 200px; overflow-y: auto; font-size: 12px; }
+    </style>
+</head>
+<body>
+    <h1>MatrixGame Client</h1>
+
+    <div class="controls">
+        <button id="connect-btn">Connect</button>
+        <button id="start-stream-btn" disabled>Start Stream</button>
+        <button id="stop-stream-btn" disabled>Stop Stream</button>
+        <select id="scene-select" disabled>
+            <option value="forest">Forest</option>
+        </select>
+    </div>
+
+    <div class="status-bar">
+        <span id="mouse-position">Mouse: 0.00, 0.00</span>
+        <span id="fps-counter">FPS: 0</span>
+    </div>
+
+    <div id="mouse-tracking-area">
+        <img id="game-canvas" alt="Game stream">
+    </div>
+
+    <div id="keyboard-panel" class="panel">
+        <div class="panel-header">Keyboard Controls <button class="toggle-button">−</button></div>
+        <div class="panel-content">
+            <div><span id="key-w" class="key-indicator">W</span></div>
+            <div>
+                <span id="key-a" class="key-indicator">A</span>
+                <span id="key-s" class="key-indicator">S</span>
+                <span id="key-d" class="key-indicator">D</span>
+            </div>
+            <div>
+                <span id="key-space" class="key-indicator">SPACE</span>
+                <span id="key-shift" class="key-indicator">SHIFT</span>
+            </div>
+            <p>W or ↑ = Forward, S or ↓ = Back, A or ← = Left, D or → = Right</p>
+            <p>Space = Jump, Shift = Attack</p>
+            <p>Click on game view to capture mouse (ESC to release)</p>
+            <p>Mouse = Look around</p>
+        </div>
+    </div>
+
+    <div id="log-panel" class="panel">
+        <div class="panel-header">Connection Log <button class="toggle-button">−</button></div>
+        <div class="panel-content">
+            <div id="connection-log">
+                <div class="log-entry">Waiting to connect...</div>
+            </div>
+        </div>
+    </div>
+
+    <!-- client.js derives its WebSocket base path from this script src -->
+    <script src="assets/client.js"></script>
+</body>
+</html>
\ No newline at end of file
diff --git a/example/requirements.txt b/example/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f903e55a951ef2222ac45cd39ee12716e34c5142
--- /dev/null
+++ b/example/requirements.txt
@@ -0,0 +1,23 @@
+diffusers==0.32.2
+einops==0.8.1
+
+#flash_attn==2.7.4.post1
+flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+
+ftfy==6.3.1
+imageio==2.34.0
+numpy==1.24.4
+opencv_python==4.9.0.80
+opencv_python_headless==4.9.0.80
+packaging==25.0
+peft==0.14.0
+Pillow==11.2.1
+regex==2024.11.6
+safetensors==0.5.3
+torch==2.5.1
+torchvision==0.20.1
+torchaudio==2.5.1
+transformers==4.47.1
+aiohttp==3.9.3
+jinja2==3.1.3
+python-multipart==0.0.6
\ No newline at end of file
diff --git a/example/server.py b/example/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..67b1703c3761987a855ad09c05f482d339d35495
--- /dev/null
+++ b/example/server.py
@@ -0,0 +1,649 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+MatrixGame WebSocket Gaming Server
+
+This script implements a WebSocket server for the MatrixGame project,
+allowing real-time streaming of game frames based on player inputs.
+"""
+
+import asyncio
+import json
+import logging
+import os
+import pathlib
+import time
+import uuid
+import base64
+import argparse
+from typing import Dict, List, Any, Optional
+from aiohttp import web, WSMsgType
+
+# Import the game engine
+from engine import MatrixGameEngine
+from utils import logger, parse_model_args, setup_gpu_environment
+
+class GameSession:
+ """
+ Represents a user's gaming session.
+ Each WebSocket connection gets its own session with separate queues.
+ """
+ def __init__(self, user_id: str, ws: web.WebSocketResponse, game_manager):
+ self.user_id = user_id
+ self.ws = ws
+ self.game_manager = game_manager
+
+ # Create action queue for this user session
+ self.action_queue = asyncio.Queue()
+
+ # Session creation time
+ self.created_at = time.time()
+ self.last_activity = time.time()
+
+ # Game state
+ self.current_scene = "forest" # Default scene
+ self.is_streaming = False
+ self.stream_task = None
+
+ # Current input state
+ self.keyboard_state = [0, 0, 0, 0, 0, 0] # forward, back, left, right, jump, attack
+ self.mouse_state = [0, 0] # x, y
+
+ self.background_tasks = []
+
+ async def start(self):
+ """Start all the queue processors for this session"""
+ self.background_tasks = [
+ asyncio.create_task(self._process_action_queue()),
+ ]
+ logger.info(f"Started game session for user {self.user_id}")
+
+ async def stop(self):
+ """Stop all background tasks for this session"""
+ # Stop streaming if active
+ if self.is_streaming and self.stream_task:
+ self.is_streaming = False
+ self.stream_task.cancel()
+ try:
+ await self.stream_task
+ except asyncio.CancelledError:
+ pass
+
+ # Cancel other background tasks
+ for task in self.background_tasks:
+ task.cancel()
+
+ try:
+ # Wait for tasks to complete cancellation
+ await asyncio.gather(*self.background_tasks, return_exceptions=True)
+ except asyncio.CancelledError:
+ pass
+
+ logger.info(f"Stopped game session for user {self.user_id}")
+
+ async def _process_action_queue(self):
+ """Process game actions from the queue"""
+ while True:
+ data = await self.action_queue.get()
+ try:
+ action_type = data.get('action')
+
+ if action_type == 'start_stream':
+ result = await self._handle_start_stream(data)
+ elif action_type == 'stop_stream':
+ result = await self._handle_stop_stream(data)
+ elif action_type == 'keyboard_input':
+ result = await self._handle_keyboard_input(data)
+ elif action_type == 'mouse_input':
+ result = await self._handle_mouse_input(data)
+ elif action_type == 'change_scene':
+ result = await self._handle_scene_change(data)
+ else:
+ result = {
+ 'action': action_type,
+ 'requestId': data.get('requestId'),
+ 'success': False,
+ 'error': f'Unknown action: {action_type}'
+ }
+
+ # Send response back to the client
+ await self.ws.send_json(result)
+
+ # Update last activity time
+ self.last_activity = time.time()
+
+ except Exception as e:
+ logger.error(f"Error processing action for user {self.user_id}: {str(e)}")
+ try:
+ await self.ws.send_json({
+ 'action': data.get('action'),
+ 'requestId': data.get('requestId', 'unknown'),
+ 'success': False,
+ 'error': f'Error processing action: {str(e)}'
+ })
+ except Exception as send_error:
+ logger.error(f"Error sending error response: {send_error}")
+ finally:
+ self.action_queue.task_done()
+
+ async def _handle_start_stream(self, data: Dict) -> Dict:
+ """Handle request to start streaming frames"""
+ if self.is_streaming:
+ return {
+ 'action': 'start_stream',
+ 'requestId': data.get('requestId'),
+ 'success': False,
+ 'error': 'Stream already active'
+ }
+
+ fps = data.get('fps', 16)
+ self.is_streaming = True
+ self.stream_task = asyncio.create_task(self._stream_frames(fps))
+
+ return {
+ 'action': 'start_stream',
+ 'requestId': data.get('requestId'),
+ 'success': True,
+ 'message': f'Streaming started at {fps} FPS'
+ }
+
+ async def _handle_stop_stream(self, data: Dict) -> Dict:
+ """Handle request to stop streaming frames"""
+ if not self.is_streaming:
+ return {
+ 'action': 'stop_stream',
+ 'requestId': data.get('requestId'),
+ 'success': False,
+ 'error': 'No active stream to stop'
+ }
+
+ self.is_streaming = False
+ if self.stream_task:
+ self.stream_task.cancel()
+ try:
+ await self.stream_task
+ except asyncio.CancelledError:
+ pass
+ self.stream_task = None
+
+ return {
+ 'action': 'stop_stream',
+ 'requestId': data.get('requestId'),
+ 'success': True,
+ 'message': 'Streaming stopped'
+ }
+
+ async def _handle_keyboard_input(self, data: Dict) -> Dict:
+ """Handle keyboard input from client"""
+ key = data.get('key', '')
+ pressed = data.get('pressed', False)
+
+ # Map key to keyboard state index
+ key_map = {
+ 'w': 0, 'forward': 0,
+ 's': 1, 'back': 1, 'backward': 1,
+ 'a': 2, 'left': 2,
+ 'd': 3, 'right': 3,
+ 'space': 4, 'jump': 4,
+ 'shift': 5, 'attack': 5, 'ctrl': 5
+ }
+
+ if key.lower() in key_map:
+ key_idx = key_map[key.lower()]
+ self.keyboard_state[key_idx] = 1 if pressed else 0
+
+ return {
+ 'action': 'keyboard_input',
+ 'requestId': data.get('requestId'),
+ 'success': True,
+ 'keyboardState': self.keyboard_state
+ }
+
+ async def _handle_mouse_input(self, data: Dict) -> Dict:
+ """Handle mouse movement/input from client"""
+ mouse_x = data.get('x', 0)
+ mouse_y = data.get('y', 0)
+
+        # Update mouse state (the client sends values already normalized to [-1, 1])
+        self.mouse_state = [float(mouse_x), float(mouse_y)]
+
+ return {
+ 'action': 'mouse_input',
+ 'requestId': data.get('requestId'),
+ 'success': True,
+ 'mouseState': self.mouse_state
+ }
+
+ async def _handle_scene_change(self, data: Dict) -> Dict:
+ """Handle scene change requests"""
+ scene_name = data.get('scene', 'forest')
+ valid_scenes = self.game_manager.valid_scenes
+
+ if scene_name not in valid_scenes:
+ return {
+ 'action': 'change_scene',
+ 'requestId': data.get('requestId'),
+ 'success': False,
+ 'error': f'Invalid scene: {scene_name}. Valid scenes are: {", ".join(valid_scenes)}'
+ }
+
+ self.current_scene = scene_name
+
+ return {
+ 'action': 'change_scene',
+ 'requestId': data.get('requestId'),
+ 'success': True,
+ 'scene': scene_name
+ }
+
+ async def _stream_frames(self, fps: int):
+ """Stream frames to the client at the specified FPS"""
+ frame_interval = 1.0 / fps # Time between frames in seconds
+
+ try:
+ while self.is_streaming:
+ start_time = time.time()
+
+ # Generate frame based on current keyboard and mouse state
+ keyboard_condition = [self.keyboard_state]
+ mouse_condition = [self.mouse_state]
+
+ # Use the engine to generate the next frame
+ frame_bytes = self.game_manager.engine.generate_frame(
+ self.current_scene, keyboard_condition, mouse_condition
+ )
+
+ # Encode as base64 for sending in JSON
+ frame_base64 = base64.b64encode(frame_bytes).decode('utf-8')
+
+ # Send frame to client
+ await self.ws.send_json({
+ 'action': 'frame',
+ 'frameData': frame_base64,
+ 'timestamp': time.time()
+ })
+
+ # Calculate sleep time to maintain FPS
+ elapsed = time.time() - start_time
+ sleep_time = max(0, frame_interval - elapsed)
+ await asyncio.sleep(sleep_time)
+
+ except asyncio.CancelledError:
+ logger.info(f"Frame streaming cancelled for user {self.user_id}")
+ except Exception as e:
+ logger.error(f"Error in frame streaming for user {self.user_id}: {str(e)}")
+ if self.ws.closed:
+ logger.info(f"WebSocket closed for user {self.user_id}")
+ return
+
+ # Notify client of error
+ try:
+ await self.ws.send_json({
+ 'action': 'frame_error',
+ 'error': f'Streaming error: {str(e)}'
+ })
+            except Exception:
+                pass
+
+ # Stop streaming
+ self.is_streaming = False
+
+class GameManager:
+ """
+ Manages all active gaming sessions and shared resources.
+ """
+ def __init__(self, args: argparse.Namespace):
+ self.sessions = {}
+ self.session_lock = asyncio.Lock()
+
+ # Initialize game engine
+ self.engine = MatrixGameEngine(args)
+
+ # Load valid scenes from engine
+ self.valid_scenes = self.engine.get_valid_scenes()
+
+ async def create_session(self, user_id: str, ws: web.WebSocketResponse) -> GameSession:
+ """Create a new game session"""
+ async with self.session_lock:
+ # Create a new session for this user
+ session = GameSession(user_id, ws, self)
+ await session.start()
+ self.sessions[user_id] = session
+ return session
+
+ async def delete_session(self, user_id: str) -> None:
+ """Delete a game session and clean up resources"""
+ async with self.session_lock:
+ if user_id in self.sessions:
+ session = self.sessions[user_id]
+ await session.stop()
+ del self.sessions[user_id]
+ logger.info(f"Deleted game session for user {user_id}")
+
+ def get_session(self, user_id: str) -> Optional[GameSession]:
+ """Get a game session if it exists"""
+ return self.sessions.get(user_id)
+
+ async def close_all_sessions(self) -> None:
+ """Close all active sessions (used during shutdown)"""
+ async with self.session_lock:
+ for user_id, session in list(self.sessions.items()):
+ await session.stop()
+ self.sessions.clear()
+ logger.info("Closed all active game sessions")
+
+ @property
+ def session_count(self) -> int:
+ """Get the number of active sessions"""
+ return len(self.sessions)
+
+ def get_session_stats(self) -> Dict:
+ """Get statistics about active sessions"""
+ stats = {
+ 'total_sessions': len(self.sessions),
+ 'active_scenes': {},
+ 'streaming_sessions': 0
+ }
+
+ # Count sessions by scene and streaming status
+ for session in self.sessions.values():
+ scene = session.current_scene
+ stats['active_scenes'][scene] = stats['active_scenes'].get(scene, 0) + 1
+ if session.is_streaming:
+ stats['streaming_sessions'] += 1
+
+ return stats
+
+# Create global game manager
+game_manager = None
+
+async def status_handler(request: web.Request) -> web.Response:
+ """Handler for API status endpoint"""
+ # Get session statistics
+ session_stats = game_manager.get_session_stats()
+
+ return web.json_response({
+ 'product': 'MatrixGame WebSocket Server',
+ 'version': '1.0.0',
+ 'active_sessions': session_stats,
+ 'available_scenes': game_manager.valid_scenes
+ })
+
+async def root_handler(request: web.Request) -> web.Response:
+ """Handler for serving the client at the root path"""
+    client_path = pathlib.Path(__file__).parent / 'index.html'
+
+ with open(client_path, 'r') as file:
+ html_content = file.read()
+
+ return web.Response(text=html_content, content_type='text/html')
+
+async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+ """Handle WebSocket connections with robust error handling"""
+ logger.info(f"WebSocket connection attempt - PATH: {request.path}, QUERY: {request.query_string}")
+
+ # Log request headers at debug level only (could contain sensitive information)
+ logger.debug(f"WebSocket request headers: {dict(request.headers)}")
+
+ # Prepare a WebSocket response with appropriate settings
+ ws = web.WebSocketResponse(
+ max_msg_size=1024*1024*10, # 10MB max message size
+ timeout=60.0,
+ heartbeat=30.0 # Add heartbeat to keep connection alive
+ )
+
+ # Check if WebSocket protocol is supported
+ if not ws.can_prepare(request):
+ logger.error("Cannot prepare WebSocket: WebSocket protocol not supported")
+ return web.Response(status=400, text="WebSocket protocol not supported")
+
+ try:
+ logger.info("Preparing WebSocket connection...")
+ await ws.prepare(request)
+
+ # Generate a unique user ID for this connection
+ user_id = str(uuid.uuid4())
+
+ # Get client IP address
+ peername = request.transport.get_extra_info('peername')
+ if peername is not None:
+ client_ip = peername[0]
+ else:
+ client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()
+
+ # Log connection success
+ logger.info(f"Client {user_id} connecting from IP: {client_ip} - WebSocket connection established")
+
+ # Track whether the session has been created, so the error path only cleans up when needed
+ is_session_created = False
+
+ try:
+ # Store the user ID in the websocket for easy access
+ ws.user_id = user_id
+
+ # Create a new session for this user
+ logger.info(f"Creating game session for user {user_id}")
+ user_session = await game_manager.create_session(user_id, ws)
+ is_session_created = True
+ logger.info(f"Game session created for user {user_id}")
+ except Exception as session_error:
+ logger.error(f"Error creating game session: {str(session_error)}", exc_info=True)
+ if not ws.closed:
+ await ws.close(code=1011, message=f"Server error: {str(session_error)}".encode())
+ if is_session_created:
+ await game_manager.delete_session(user_id)
+ return ws
+ except Exception as e:
+ logger.error(f"Error establishing WebSocket connection: {str(e)}", exc_info=True)
+ if not ws.closed and ws.prepared:
+ await ws.close(code=1011, message=f"Server error: {str(e)}".encode())
+ return ws
+
+ # Send initial welcome message
+ try:
+ await ws.send_json({
+ 'action': 'welcome',
+ 'userId': user_id,
+ 'message': 'Welcome to the MatrixGame WebSocket server!',
+ 'scenes': game_manager.valid_scenes
+ })
+ logger.info(f"Sent welcome message to user {user_id}")
+ except Exception as welcome_error:
+ logger.error(f"Error sending welcome message: {str(welcome_error)}")
+ if not ws.closed:
+ await ws.close(code=1011, message=b"Failed to send welcome message")
+ await game_manager.delete_session(user_id)
+ return ws
+
+ try:
+ async for msg in ws:
+ if msg.type == WSMsgType.TEXT:
+ try:
+ data = json.loads(msg.data)
+ action = data.get('action')
+
+ logger.debug(f"Received {action} message from user {user_id}")
+
+ if action == 'ping':
+ # Respond to ping immediately
+ await ws.send_json({
+ 'action': 'pong',
+ 'requestId': data.get('requestId'),
+ 'timestamp': time.time()
+ })
+ else:
+ # Route game actions to the session's action queue
+ await user_session.action_queue.put(data)
+
+ except json.JSONDecodeError:
+ logger.error(f"Invalid JSON from user {user_id}: {msg.data}")
+ if not ws.closed:
+ await ws.send_json({
+ 'error': 'Invalid JSON message',
+ 'success': False
+ })
+ except Exception as e:
+ logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
+ if not ws.closed:
+ await ws.send_json({
+ 'action': data.get('action') if 'data' in locals() else 'unknown',
+ 'success': False,
+ 'error': f'Error processing message: {str(e)}'
+ })
+
+ elif msg.type == WSMsgType.ERROR:
+ logger.error(f"WebSocket error for user {user_id}: {ws.exception()}")
+ break
+
+ elif msg.type == WSMsgType.CLOSE:
+ logger.info(f"WebSocket close received for user {user_id} (code: {msg.data}, message: {msg.extra})")
+ break
+
+ elif msg.type == WSMsgType.CLOSING:
+ logger.info(f"WebSocket closing for user {user_id}")
+ break
+
+ elif msg.type == WSMsgType.CLOSED:
+ logger.info(f"WebSocket already closed for user {user_id}")
+ break
+
+ except Exception as ws_error:
+ logger.error(f"Unexpected WebSocket error for user {user_id}: {str(ws_error)}", exc_info=True)
+ finally:
+ # Cleanup session
+ try:
+ logger.info(f"Cleaning up session for user {user_id}")
+ await game_manager.delete_session(user_id)
+ logger.info(f"Connection closed for user {user_id}")
+ except Exception as cleanup_error:
+ logger.error(f"Error during session cleanup for user {user_id}: {str(cleanup_error)}")
+
+ return ws
+
+async def init_app(args, base_path="") -> web.Application:
+ """Initialize the web application"""
+ global game_manager
+
+ # Initialize game manager with command line args
+ game_manager = GameManager(args)
+
+ app = web.Application(
+ client_max_size=1024**2*10 # 10MB max size
+ )
+
+ # Add cleanup logic
+ async def cleanup(app):
+ logger.info("Shutting down server, closing all sessions...")
+ await game_manager.close_all_sessions()
+
+ app.on_shutdown.append(cleanup)
+
+ # Configure permissive CORS for all routes (cross-origin WebSocket and API clients)
+ @web.middleware
+ async def cors_middleware(request, handler):
+ if request.method == 'OPTIONS':
+ # Handle preflight requests
+ resp = web.Response()
+ resp.headers['Access-Control-Allow-Origin'] = '*'
+ resp.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
+ resp.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
+ return resp
+
+ # Normal request, call the handler
+ resp = await handler(request)
+
+ # Add CORS headers to the response
+ resp.headers['Access-Control-Allow-Origin'] = '*'
+ resp.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
+ resp.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
+ return resp
+
+ app.middlewares.append(cors_middleware)
+
+ # Add a debug endpoint to help diagnose WebSocket issues
+ async def debug_handler(request):
+ client_ip = request.remote
+ headers = dict(request.headers)
+ server_host = request.host
+
+ debug_info = {
+ "client_ip": client_ip,
+ "server_host": server_host,
+ "headers": headers,
+ "request_path": request.path,
+ "server_time": time.time(),
+ "base_path": base_path,
+ "websocket_route": f"{base_path}/ws",
+ "all_routes": [route.name for route in app.router.routes() if route.name],
+ "server_info": {
+ "active_sessions": game_manager.session_count,
+ "available_scenes": game_manager.valid_scenes
+ }
+ }
+
+ return web.json_response(debug_info)
+
+ # Set up routes with the base_path
+ # Add multiple WebSocket routes to ensure compatibility
+ logger.info(f"Setting up WebSocket route at {base_path}/ws")
+ app.router.add_get(f'{base_path}/ws', websocket_handler, name='ws_handler')
+
+ # Also add WebSocket route at the root for Hugging Face compatibility
+ if base_path:
+ logger.info(f"Adding additional WebSocket route at /ws")
+ app.router.add_get('/ws', websocket_handler, name='ws_root_handler')
+
+ # Add routes for API and debug endpoints
+ app.router.add_get(f'{base_path}/api/status', status_handler, name='status_handler')
+ app.router.add_get(f'{base_path}/api/debug', debug_handler, name='debug_handler')
+
+ # Serve the client at both the base path and root path for compatibility
+ app.router.add_get(f'{base_path}/', root_handler, name='root_handler')
+
+ # Always serve at the root path for Hugging Face Spaces compatibility
+ if base_path:
+ app.router.add_get('/', root_handler, name='root_handler_no_base')
+
+ # Set up static file serving for the client assets
+ app.router.add_static(f'{base_path}/assets', pathlib.Path(__file__).parent / 'client', name='static_handler')
+
+ # Add static file serving at root for compatibility
+ if base_path:
+ app.router.add_static('/assets', pathlib.Path(__file__).parent / 'client', name='static_handler_no_base')
+
+ return app
+
+def parse_args() -> argparse.Namespace:
+ """Parse server-specific command line arguments"""
+ parser = argparse.ArgumentParser(description="MatrixGame WebSocket Server")
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host IP to bind to")
+ parser.add_argument("--port", type=int, default=8080, help="Port to listen on")
+ parser.add_argument("--path", type=str, default="", help="Base path for the server (for proxy setups)")
+
+ # Parse server args first
+ server_args, remaining_args = parser.parse_known_args()
+
+ # Parse model args and combine
+ model_args = parse_model_args()
+
+ # Combine all args
+ combined_args = argparse.Namespace(**vars(server_args), **vars(model_args))
+
+ return combined_args
+
+if __name__ == '__main__':
+ # Configure GPU environment
+ setup_gpu_environment()
+
+ # Parse command line arguments
+ args = parse_args()
+
+ # Initialize app
+ # asyncio.run avoids the deprecated get_event_loop() pattern (matches reference_example/api.py)
+ app = asyncio.run(init_app(args, base_path=args.path))
+
+ # Start server
+ logger.info(f"Starting MatrixGame WebSocket Server at {args.host}:{args.port}")
+ web.run_app(app, host=args.host, port=args.port)
\ No newline at end of file
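
For reviewers who want to poke at the protocol above, here is a minimal client sketch. It is an illustration, not part of the PR: it assumes a server running locally on the default host/port with an empty base path, and only exercises the welcome and ping/pong exchange handled inline by `websocket_handler` (all other actions are queued on the session).

```python
import asyncio

import aiohttp


async def main() -> None:
    async with aiohttp.ClientSession() as session:
        # The /ws route registered in init_app (assumes a local server, no base path).
        async with session.ws_connect("http://localhost:8080/ws") as ws:
            # The server pushes a 'welcome' message right after the session is created.
            welcome = await ws.receive_json()
            print("user:", welcome["userId"], "scenes:", welcome["scenes"])

            # 'ping' is answered inline with a 'pong' echoing our requestId.
            await ws.send_json({"action": "ping", "requestId": "example-1"})
            pong = await ws.receive_json()
            print("pong at", pong["timestamp"])


asyncio.run(main())
```
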
diff --git a/example/utils.py b/example/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2c35aa64e6d4f4cc1452b3c4789143ce2be6ef7
--- /dev/null
+++ b/example/utils.py
@@ -0,0 +1,202 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+MatrixGame Utility Functions
+
+This module contains helper functions and utilities for the MatrixGame project.
+"""
+
+import os
+import logging
+import argparse
+import torch
+import numpy as np
+import cv2
+from PIL import Image
+from typing import Dict, List, Tuple, Any, Optional, Union
+
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+def setup_gpu_environment():
+ """
+ Configure the GPU environment and log GPU information.
+
+ Returns:
+ bool: True if CUDA is available, False otherwise
+ """
+ # Set CUDA memory allocation environment variable for better performance
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
+
+ # Check if CUDA is available and log information
+ if torch.cuda.is_available():
+ gpu_count = torch.cuda.device_count()
+ gpu_info = []
+
+ for i in range(gpu_count):
+ gpu_name = torch.cuda.get_device_name(i)
+ gpu_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3) # Convert to GB
+ gpu_info.append(f"GPU {i}: {gpu_name} ({gpu_memory:.2f} GB)")
+
+ logger.info(f"CUDA is available. Found {gpu_count} GPU(s):")
+ for info in gpu_info:
+ logger.info(f" {info}")
+ return True
+ else:
+ logger.warning("CUDA is not available. Running in CPU-only mode.")
+ return False
+
+def parse_model_args() -> argparse.Namespace:
+ """
+ Parse command line arguments for model paths and configuration.
+
+ Returns:
+ argparse.Namespace: Parsed arguments
+ """
+ parser = argparse.ArgumentParser(description="MatrixGame Model Configuration")
+
+ # Model paths
+ parser.add_argument("--model_root", type=str, default="./models/matrixgame",
+ help="Root directory for model files")
+ parser.add_argument("--dit_path", type=str, default=None,
+ help="Path to DIT model. If not provided, will use MODEL_ROOT/dit/")
+ parser.add_argument("--vae_path", type=str, default=None,
+ help="Path to VAE model. If not provided, will use MODEL_ROOT/vae/")
+ parser.add_argument("--textenc_path", type=str, default=None,
+ help="Path to text encoder model. If not provided, will use MODEL_ROOT")
+
+ # Model settings
+ parser.add_argument("--inference_steps", type=int, default=20,
+ help="Number of inference steps for frame generation (lower is faster)")
+ parser.add_argument("--guidance_scale", type=float, default=6.0,
+ help="Guidance scale for generation")
+ parser.add_argument("--frame_width", type=int, default=640,
+ help="Width of the generated frames")
+ parser.add_argument("--frame_height", type=int, default=360,
+ help="Height of the generated frames")
+ parser.add_argument("--num_pre_frames", type=int, default=3,
+ help="Number of pre-frames for conditioning")
+ parser.add_argument("--fps", type=int, default=16,
+ help="Frames per second for video")
+
+ # Use parse_known_args so server-level flags (e.g. --host, --port) don't trip this parser
+ args, _ = parser.parse_known_args()
+
+ # Set environment variables for model paths if provided
+ if args.model_root:
+ os.environ.setdefault("MODEL_ROOT", args.model_root)
+ if args.dit_path:
+ os.environ.setdefault("DIT_PATH", args.dit_path)
+ else:
+ os.environ.setdefault("DIT_PATH", os.path.join(os.environ.get("MODEL_ROOT", "./models/matrixgame"), "dit/"))
+ if args.vae_path:
+ os.environ.setdefault("VAE_PATH", args.vae_path)
+ else:
+ os.environ.setdefault("VAE_PATH", os.path.join(os.environ.get("MODEL_ROOT", "./models/matrixgame"), "vae/"))
+ if args.textenc_path:
+ os.environ.setdefault("TEXTENC_PATH", args.textenc_path)
+ else:
+ os.environ.setdefault("TEXTENC_PATH", os.environ.get("MODEL_ROOT", "./models/matrixgame"))
+
+ return args
+
+def visualize_controls(frame: np.ndarray, keyboard_condition: List, mouse_condition: List,
+ frame_width: int, frame_height: int) -> np.ndarray:
+ """
+ Visualize keyboard and mouse controls on the frame.
+
+ Args:
+ frame: The video frame to visualize on
+ keyboard_condition: Keyboard state as a list
+ mouse_condition: Mouse state as a list
+ frame_width: Width of the frame
+ frame_height: Height of the frame
+
+ Returns:
+ np.ndarray: Frame with visualized controls
+ """
+ # Clone the frame to avoid modifying the original
+ frame = frame.copy()
+
+ # If we have keyboard/mouse conditions, visualize them on the frame
+ if keyboard_condition:
+ # Visualize keyboard inputs
+ keys = ["W", "S", "A", "D", "JUMP", "ATTACK"]
+ for i, key_pressed in enumerate(keyboard_condition[0][:len(keys)]):
+ color = (0, 255, 0) if key_pressed else (100, 100, 100)
+ cv2.putText(frame, keys[i], (20 + i*100, 30),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
+
+ if mouse_condition:
+ # Visualize mouse movement
+ mouse_x, mouse_y = mouse_condition[0]
+ # Scale mouse values for visualization
+ offset_x = int(mouse_x * 100)
+ offset_y = int(mouse_y * 100)
+ center_x, center_y = frame_width // 2, frame_height // 2
+ cv2.circle(frame, (center_x + offset_x, center_y - offset_y), 10, (255, 0, 0), -1)
+ cv2.putText(frame, f"Mouse: {mouse_x:.2f}, {mouse_y:.2f}",
+ (frame_width - 250, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
+
+ return frame
+
+def frame_to_jpeg(frame: np.ndarray, frame_height: int, frame_width: int) -> bytes:
+ """
+ Convert a frame to JPEG bytes.
+
+ Args:
+ frame: The video frame to convert
+ frame_height: Height of the frame for fallback
+ frame_width: Width of the frame for fallback
+
+ Returns:
+ bytes: JPEG bytes of the frame
+ """
+ success, buffer = cv2.imencode('.jpg', frame)
+ if not success:
+ logger.error("Failed to encode frame as JPEG")
+ # Return a blank frame
+ blank = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 100
+ success, buffer = cv2.imencode('.jpg', blank)
+
+ return buffer.tobytes()
+
+def load_scene_frames(scene_name: str, frame_width: int, frame_height: int) -> List[np.ndarray]:
+ """
+ Load initial frames for a scene from asset directory.
+
+ Args:
+ scene_name: Name of the scene
+ frame_width: Width to resize frames to
+ frame_height: Height to resize frames to
+
+ Returns:
+ List[np.ndarray]: List of frames as numpy arrays
+ """
+ frames = []
+ scene_dir = f"./GameWorldScore/asset/init_image/{scene_name}"
+
+ if os.path.exists(scene_dir):
+ image_files = sorted([f for f in os.listdir(scene_dir) if f.endswith(('.png', '.jpg'))])
+ for img_file in image_files:
+ try:
+ img_path = os.path.join(scene_dir, img_file)
+ img = Image.open(img_path).convert("RGB")
+ img = img.resize((frame_width, frame_height))
+ frames.append(np.array(img))
+ except Exception as e:
+ logger.error(f"Error loading image {img_file}: {str(e)}")
+
+ # If no frames were loaded, create a default colored frame with text
+ if not frames:
+ frame = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 100
+ # Add scene name as text
+ cv2.putText(frame, f"Scene: {scene_name}", (50, 180),
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
+ frames.append(frame)
+
+ return frames
\ No newline at end of file
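
A small usage sketch for these helpers (not part of the PR; it assumes the module is importable as `example.utils`, and the scene name is illustrative). It uses the single-batch condition layout expected by `visualize_controls`: six key flags and one (x, y) mouse pair.

```python
import numpy as np

from example.utils import frame_to_jpeg, load_scene_frames, visualize_controls

W, H = 640, 360

# "forest" is a hypothetical scene name; load_scene_frames falls back to a
# gray placeholder frame if the scene assets are missing.
frame = load_scene_frames("forest", W, H)[0]

# One batch entry each: [W, S, A, D, JUMP, ATTACK] flags and an (x, y) mouse offset.
keyboard = [[1, 0, 0, 1, 0, 0]]
mouse = [[0.25, -0.10]]

annotated = visualize_controls(frame, keyboard, mouse, W, H)
print(f"JPEG size: {len(frame_to_jpeg(annotated, H, W))} bytes")
```
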
diff --git a/game/spawn/1/act.npy b/game/spawn/1/act.npy
new file mode 100644
index 0000000000000000000000000000000000000000..2924767c1a504eb1e15ce0ac28d39624e61fe7d0
--- /dev/null
+++ b/game/spawn/1/act.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7caabf75f45d4c8bae5c0b66dc2b5a3cbf3ab7dbf89521d6ba539c4f30048d75
+size 10688
diff --git a/game/spawn/1/full_res.npy b/game/spawn/1/full_res.npy
new file mode 100644
index 0000000000000000000000000000000000000000..4e6c6c2ff258a8d56c070de131ed13d63968d5d7
--- /dev/null
+++ b/game/spawn/1/full_res.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c186e607a8cc4922e17ee66c1f37dc4858adfef220b6bf48fbcda9bf75ffde34
+size 22260128
diff --git a/game/spawn/1/low_res.npy b/game/spawn/1/low_res.npy
new file mode 100644
index 0000000000000000000000000000000000000000..e7b92ab83886e42ab0fc6d3c4146d34a846093a0
--- /dev/null
+++ b/game/spawn/1/low_res.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08e31f7ef447a1dffda3b51b4102ae4301f5804145b63c71acc5c278a294b1ee
+size 368768
diff --git a/game/spawn/1/next_act.npy b/game/spawn/1/next_act.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7c6afd9bd0075b320e33f95e49cad10748b53dbe
--- /dev/null
+++ b/game/spawn/1/next_act.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0f2d1c96337459ddd84b9c1a8dbad9eb7284cb813a2f0ebd9cb4d757dd294e1
+size 105728
diff --git a/game/spawn/2/act.npy b/game/spawn/2/act.npy
new file mode 100644
index 0000000000000000000000000000000000000000..4745e120d72bd68c4de9b05c7b31ae9b6ba17894
--- /dev/null
+++ b/game/spawn/2/act.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:619cdda3de7f55e48a753b64542354387294c2688a6e014af85b094996b8a486
+size 10688
diff --git a/game/spawn/2/full_res.npy b/game/spawn/2/full_res.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7d5357363ae4968b166615ce8bde52240a94fcde
--- /dev/null
+++ b/game/spawn/2/full_res.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:911389a5957acd8d7b96fccb1917ee4a4b0c74f0b9420f61580b571965dd99ff
+size 22260128
diff --git a/game/spawn/2/low_res.npy b/game/spawn/2/low_res.npy
new file mode 100644
index 0000000000000000000000000000000000000000..dc811a357c4143af270640a78f2a1a0be290bf2c
--- /dev/null
+++ b/game/spawn/2/low_res.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8351be59118b4c2237800a9119937d472564e63c70a8d1911bc2d52ac3a95a2
+size 368768
diff --git a/game/spawn/2/next_act.npy b/game/spawn/2/next_act.npy
new file mode 100644
index 0000000000000000000000000000000000000000..d1a89ff69cc6d2aabd893a9c67cd58045b4fe61f
--- /dev/null
+++ b/game/spawn/2/next_act.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba1b85fd4548f805352306a04e24368718328bfd513cb95305d0f9284fe9719f
+size 105728
diff --git a/game/spawn/3/act.npy b/game/spawn/3/act.npy
new file mode 100644
index 0000000000000000000000000000000000000000..795e24fb86302d7e4f38ef66d3a3ae87bd513fc6
--- /dev/null
+++ b/game/spawn/3/act.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60ff3ff40a48105e33feae08979dd4d7d7570984e10c950aae9308c08841400a
+size 10688
diff --git a/game/spawn/3/full_res.npy b/game/spawn/3/full_res.npy
new file mode 100644
index 0000000000000000000000000000000000000000..71f10372b77761a15339fb5f1758a10e59dbf86f
--- /dev/null
+++ b/game/spawn/3/full_res.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f1f8a9a266e05f8d8a741dadb697bf3ed89c1a166dab443c1d81091c3cf1824b
+size 22260128
diff --git a/game/spawn/3/low_res.npy b/game/spawn/3/low_res.npy
new file mode 100644
index 0000000000000000000000000000000000000000..de23d6a00891c2f79014c4981f6e2c1b528d1d19
--- /dev/null
+++ b/game/spawn/3/low_res.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0dd822892eb5201e23d5298da69523a4d43d740e37d7647d30a288a5b440991e
+size 368768
diff --git a/game/spawn/3/next_act.npy b/game/spawn/3/next_act.npy
new file mode 100644
index 0000000000000000000000000000000000000000..17f461d04b5278bad039cee29e2c1966a20b5528
--- /dev/null
+++ b/game/spawn/3/next_act.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:40f2a5d1381a93fc46e58f7ac0bf5df9f507c5c3b6489e85a0896deba8da9dca
+size 105728
diff --git a/index.html b/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..90e4922b2881ce5d0f18ee2e4adad4ec10f623d3
--- /dev/null
+++ b/index.html
@@ -0,0 +1,928 @@
+<!-- index.html (928 lines): single-page client for the AI Game Multiverse server.
+     The markup did not survive extraction in this excerpt; the recoverable content is:
+     page title "AI Game Multiverse" with the tagline "Play procedurally generated games using AI";
+     live "Mouse: 0.00, 0.00" and "FPS: 0" readouts over the game view;
+     a "Keyboard Controls" panel (W or ↑ = Forward, S or ↓ = Back, A or ← = Left, D or → = Right,
+     Space = Jump, Shift = Attack; click on the game view to capture the mouse, ESC to release,
+     Mouse = Look around);
+     and a "Connection Log" panel that opens with
+     "Welcome to AI Game Multiverse. Click Connect to begin." -->
\ No newline at end of file
diff --git a/reference_example/Dockerfile b/reference_example/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..427f95c1602875b9f737a838afd29ae250e84831
--- /dev/null
+++ b/reference_example/Dockerfile
@@ -0,0 +1,52 @@
+FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+ENV PYTHONUNBUFFERED=1
+
+RUN apt-get update && apt-get install --no-install-recommends -y \
+ build-essential \
+ python3.11 \
+ python3-pip \
+ python3-dev \
+ git \
+ curl \
+ ffmpeg \
+ libglib2.0-0 \
+ libsm6 \
+ libxrender1 \
+ libxext6 \
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /code
+
+COPY ./requirements.txt /code/requirements.txt
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+# Switch to the "user" user
+USER user
+# Set home to the user's home directory
+ENV HOME=/home/user \
+ PATH=/home/user/.local/bin:$PATH
+
+# Set home to the user's home directory
+ENV PYTHONPATH=$HOME/app \
+ PYTHONUNBUFFERED=1 \
+ DATA_ROOT=/tmp/data
+
+RUN echo "Installing requirements.txt"
+RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
+
+# yeah.. this is manual for now
+#RUN flutter build web
+
+WORKDIR $HOME/app
+
+COPY --chown=user . $HOME/app
+
+EXPOSE 8080
+
+ENV PORT=8080
+
+CMD python3 api.py
diff --git a/reference_example/api.py b/reference_example/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..5be3e44d40ab1e88c37f36332d9349e9ddde884b
--- /dev/null
+++ b/reference_example/api.py
@@ -0,0 +1,297 @@
+import asyncio
+import json
+import logging
+import os
+import pathlib
+import time
+import uuid
+from aiohttp import web, WSMsgType
+from typing import Dict, Any
+
+from api_core import VideoGenerationAPI
+from api_session import SessionManager
+from api_metrics import MetricsTracker
+from api_config import *
+
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# Create global session and metrics managers
+session_manager = SessionManager()
+metrics_tracker = MetricsTracker()
+
+# Dictionary to track connected anonymous clients by IP address
+anon_connections = {}
+anon_connection_lock = asyncio.Lock()
+
+async def status_handler(request: web.Request) -> web.Response:
+ """Handler for API status endpoint"""
+ api = session_manager.shared_api
+
+ # Get current busy status of all endpoints
+ endpoint_statuses = []
+ for ep in api.endpoint_manager.endpoints:
+ endpoint_statuses.append({
+ 'id': ep.id,
+ 'url': ep.url,
+ 'busy': ep.busy,
+ 'last_used': ep.last_used,
+ 'error_count': ep.error_count,
+ 'error_until': ep.error_until
+ })
+
+ # Get session statistics
+ session_stats = session_manager.get_session_stats()
+
+ # Get metrics
+ api_metrics = metrics_tracker.get_metrics()
+
+ return web.json_response({
+ 'product': PRODUCT_NAME,
+ 'version': PRODUCT_VERSION,
+ 'maintenance_mode': MAINTENANCE_MODE,
+ 'available_endpoints': len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS),
+ 'endpoint_status': endpoint_statuses,
+ 'active_endpoints': sum(1 for ep in endpoint_statuses if not ep['busy'] and ('error_until' not in ep or ep['error_until'] < time.time())),
+ 'active_sessions': session_stats,
+ 'metrics': api_metrics
+ })
+
+async def metrics_handler(request: web.Request) -> web.Response:
+ """Handler for detailed metrics endpoint (protected)"""
+ # Check for API key in header or query param
+ auth_header = request.headers.get('Authorization', '')
+ api_key = None
+
+ if auth_header.startswith('Bearer '):
+ api_key = auth_header[7:]
+ else:
+ api_key = request.query.get('key')
+
+ # Validate API key (using SECRET_TOKEN as the API key)
+ if not api_key or api_key != SECRET_TOKEN:
+ return web.json_response({
+ 'error': 'Unauthorized'
+ }, status=401)
+
+ # Get detailed metrics
+ detailed_metrics = metrics_tracker.get_detailed_metrics()
+
+ return web.json_response(detailed_metrics)
+
+async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+ # Check if maintenance mode is enabled
+ if MAINTENANCE_MODE:
+ # Return an error response indicating maintenance mode
+ return web.json_response({
+ 'error': 'Server is in maintenance mode',
+ 'maintenance': True
+ }, status=503) # 503 Service Unavailable
+
+ ws = web.WebSocketResponse(
+ max_msg_size=1024*1024*20, # 20MB max message size
+ timeout=30.0 # we want to keep things tight and short
+ )
+
+ await ws.prepare(request)
+
+ # Get the Hugging Face token from query parameters
+ hf_token = request.query.get('hf_token', '')
+
+ # Generate a unique user ID for this connection
+ user_id = str(uuid.uuid4())
+
+ # Validate the token and determine the user role
+ user_role = await session_manager.shared_api.validate_user_token(hf_token)
+ logger.info(f"User {user_id} connected with role: {user_role}")
+
+ # Get client IP address
+ peername = request.transport.get_extra_info('peername')
+ if peername is not None:
+ client_ip = peername[0]
+ else:
+ client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()
+
+ logger.info(f"Client {user_id} connecting from IP: {client_ip} with role: {user_role}")
+
+ # Check for anonymous user connection limits
+ if user_role == 'anon':
+ async with anon_connection_lock:
+ # Track this connection
+ anon_connections[client_ip] = anon_connections.get(client_ip, 0) + 1
+ # Store the IP so we can clean up later
+ ws.client_ip = client_ip
+
+ # Log multiple connections from same IP but don't restrict them
+ if anon_connections[client_ip] > 1:
+ logger.info(f"Multiple anonymous connections from IP {client_ip}: {anon_connections[client_ip]} connections")
+
+ # Store the user role in the websocket for easy access
+ ws.user_role = user_role
+ ws.user_id = user_id
+
+ # Register with metrics
+ metrics_tracker.register_session(user_id, client_ip)
+
+ # Create a new session for this user
+ user_session = await session_manager.create_session(user_id, user_role, ws)
+
+ try:
+ async for msg in ws:
+ if msg.type == WSMsgType.TEXT:
+ try:
+ data = json.loads(msg.data)
+ action = data.get('action')
+
+ # Check for rate limiting
+ request_type = 'other'
+ if action in ['join_chat', 'leave_chat', 'chat_message']:
+ request_type = 'chat'
+ elif action in ['generate_video']:
+ request_type = 'video'
+ elif action == 'search':
+ request_type = 'search'
+ elif action == 'simulate':
+ request_type = 'simulation'
+
+ # Record the request for metrics
+ await metrics_tracker.record_request(user_id, client_ip, request_type, user_role)
+
+ # Check rate limits (except for admins)
+ if user_role != 'admin' and await metrics_tracker.is_rate_limited(user_id, request_type, user_role):
+ await ws.send_json({
+ 'action': action,
+ 'requestId': data.get('requestId'),
+ 'success': False,
+ 'error': f'Rate limit exceeded for {request_type} requests. Please try again later.'
+ })
+ continue
+
+ # Route requests to appropriate queues
+ if action in ['join_chat', 'leave_chat', 'chat_message']:
+ await user_session.chat_queue.put(data)
+ elif action in ['generate_video']:
+ await user_session.video_queue.put(data)
+ elif action == 'search':
+ await user_session.search_queue.put(data)
+ elif action == 'simulate':
+ await user_session.simulation_queue.put(data)
+ else:
+ await user_session.process_generic_request(data)
+
+ except Exception as e:
+ logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
+ await ws.send_json({
+ 'action': data.get('action') if 'data' in locals() else 'unknown',
+ 'success': False,
+ 'error': f'Error processing message: {str(e)}'
+ })
+
+ elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
+ break
+
+ finally:
+ # Cleanup session
+ await session_manager.delete_session(user_id)
+
+ # Cleanup anonymous connection tracking
+ if getattr(ws, 'user_role', None) == 'anon' and hasattr(ws, 'client_ip'):
+ client_ip = ws.client_ip
+ async with anon_connection_lock:
+ if client_ip in anon_connections:
+ anon_connections[client_ip] = max(0, anon_connections[client_ip] - 1)
+ if anon_connections[client_ip] == 0:
+ del anon_connections[client_ip]
+ logger.info(f"Anonymous connection from {client_ip} closed. Remaining: {anon_connections.get(client_ip, 0)}")
+
+ # Unregister from metrics
+ metrics_tracker.unregister_session(user_id, client_ip)
+ logger.info(f"Connection closed for user {user_id}")
+
+ return ws
+
+async def init_app() -> web.Application:
+ app = web.Application(
+ client_max_size=1024**2*20 # 20MB max size
+ )
+
+ # Add cleanup logic
+ async def cleanup(app):
+ logger.info("Shutting down server, closing all sessions...")
+ await session_manager.close_all_sessions()
+
+ app.on_shutdown.append(cleanup)
+
+ # Add routes
+ app.router.add_get('/ws', websocket_handler)
+ app.router.add_get('/api/status', status_handler)
+ app.router.add_get('/api/metrics', metrics_handler)
+
+ # Set up static file serving
+ # Define the path to the public directory
+ public_path = pathlib.Path(__file__).parent / 'build' / 'web'
+ if not public_path.exists():
+ public_path.mkdir(parents=True, exist_ok=True)
+
+ # Set up static file serving with proper security considerations
+ async def static_file_handler(request):
+ # Get the path from the request (removing leading /)
+ path_parts = request.path.lstrip('/').split('/')
+
+ # Convert to safe path to prevent path traversal attacks
+ safe_path = public_path.joinpath(*path_parts)
+
+ # Make sure the path is within the public directory (prevent directory traversal)
+ try:
+ safe_path = safe_path.resolve()
+ if not str(safe_path).startswith(str(public_path.resolve())):
+ return web.HTTPForbidden(text="Access denied")
+ except (ValueError, FileNotFoundError):
+ return web.HTTPNotFound()
+
+ # If path is a directory, look for index.html
+ if safe_path.is_dir():
+ safe_path = safe_path / 'index.html'
+
+ # Check if the file exists
+ if not safe_path.exists() or not safe_path.is_file():
+ # If not found, serve index.html (for SPA routing)
+ safe_path = public_path / 'index.html'
+ if not safe_path.exists():
+ return web.HTTPNotFound()
+
+ # Determine content type based on file extension
+ content_type = 'text/plain'
+ ext = safe_path.suffix.lower()
+ if ext == '.html':
+ content_type = 'text/html'
+ elif ext == '.js':
+ content_type = 'application/javascript'
+ elif ext == '.css':
+ content_type = 'text/css'
+ elif ext in ('.jpg', '.jpeg'):
+ content_type = 'image/jpeg'
+ elif ext == '.png':
+ content_type = 'image/png'
+ elif ext == '.gif':
+ content_type = 'image/gif'
+ elif ext == '.svg':
+ content_type = 'image/svg+xml'
+ elif ext == '.json':
+ content_type = 'application/json'
+
+ # Return the file with appropriate headers
+ return web.FileResponse(safe_path, headers={'Content-Type': content_type})
+
+ # Add catch-all route for static files (lower priority than API routes)
+ app.router.add_get('/{path:.*}', static_file_handler)
+
+ return app
+
+if __name__ == '__main__':
+ app = asyncio.run(init_app())
+ web.run_app(app, host='0.0.0.0', port=8080)
\ No newline at end of file
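
As a quick smoke test for the reference server, one could poll the status endpoint like this (a sketch; the URL assumes a local instance, and the keys mirror what `status_handler` returns):

```python
import asyncio

import aiohttp


async def poll_status() -> None:
    async with aiohttp.ClientSession() as session:
        async with session.get("http://localhost:8080/api/status") as resp:
            status = await resp.json()
            # 'active_endpoints' counts endpoints that are neither busy nor in an error window.
            print(status["product"], status["version"],
                  "free endpoints:", status["active_endpoints"])


asyncio.run(poll_status())
```
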
diff --git a/reference_example/api_config.py b/reference_example/api_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b366493d35c59e610b47f8db2b84238e99b1771e
--- /dev/null
+++ b/reference_example/api_config.py
@@ -0,0 +1,184 @@
+import os
+
+PRODUCT_NAME = os.environ.get('PRODUCT_NAME', 'TikSlop')
+PRODUCT_VERSION = "2.0.0"
+
+# Use a model like Mistral 7B Instruct for a good balance of performance and accuracy
+TEXT_MODEL = os.environ.get('HF_TEXT_MODEL', '')
+
+# Environment variable to control maintenance mode
+MAINTENANCE_MODE = os.environ.get('MAINTENANCE_MODE', 'false').lower() in ('true', 'yes', '1', 't')
+
+# Environment variable to control how many nodes to use
+MAX_NODES = int(os.environ.get('MAX_NODES', '8'))
+
+ADMIN_ACCOUNTS = [
+ "jbilcke-hf"
+]
+
+RAW_VIDEO_ROUND_ROBIN_ENDPOINT_URLS = [
+ os.environ.get('VIDEO_ROUND_ROBIN_SERVER_1', ''),
+ os.environ.get('VIDEO_ROUND_ROBIN_SERVER_2', ''),
+ os.environ.get('VIDEO_ROUND_ROBIN_SERVER_3', ''),
+ os.environ.get('VIDEO_ROUND_ROBIN_SERVER_4', ''),
+ os.environ.get('VIDEO_ROUND_ROBIN_SERVER_5', ''),
+ os.environ.get('VIDEO_ROUND_ROBIN_SERVER_6', ''),
+ os.environ.get('VIDEO_ROUND_ROBIN_SERVER_7', ''),
+ os.environ.get('VIDEO_ROUND_ROBIN_SERVER_8', ''),
+]
+
+# Filter out empty strings from the endpoint list
+filtered_urls = [url for url in RAW_VIDEO_ROUND_ROBIN_ENDPOINT_URLS if url]
+
+# Limit the number of URLs based on MAX_NODES environment variable
+VIDEO_ROUND_ROBIN_ENDPOINT_URLS = filtered_urls[:MAX_NODES]
+
+HF_TOKEN = os.environ.get('HF_TOKEN')
+
+# use the same secret token as you used to secure your BASE_SPACE_NAME spaces
+SECRET_TOKEN = os.environ.get('SECRET_TOKEN')
+
+# alternative words we could use: "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"
+NEGATIVE_PROMPT = "low quality, worst quality, deformed, distorted, disfigured, blurry, text, watermark"
+
+POSITIVE_PROMPT_SUFFIX = "high quality, cinematic, 4K, intricate details"
+
+GUIDANCE_SCALE = 1.0
+
+THUMBNAIL_FRAMES = 65
+
+# Anonymous users are people browsing TikSlop without being logged in.
+# This category suffers from regular abuse, so we enforce strict limitations.
+CONFIG_FOR_ANONYMOUS_USERS = {
+
+ # anons can only watch 2 minutes per video
+ "max_rendering_time_per_client_per_video_in_sec": 2 * 60,
+
+ "min_num_inference_steps": 2,
+ "default_num_inference_steps": 4,
+ "max_num_inference_steps": 4,
+
+ "min_num_frames": 9, # 8 + 1
+ "default_max_num_frames": 65, # 8*8 + 1
+ "max_num_frames": 65, # 8*8 + 1
+
+ "min_clip_duration_seconds": 1,
+ "default_clip_duration_seconds": 2,
+ "max_clip_duration_seconds": 2,
+
+ "min_clip_playback_speed": 0.7,
+ "default_clip_playback_speed": 0.7,
+ "max_clip_playback_speed": 0.7,
+
+ "min_clip_framerate": 8,
+ "default_clip_framerate": 16,
+ "max_clip_framerate": 16,
+
+ "min_clip_width": 544,
+ "default_clip_width": 640,
+ "max_clip_width": 640,
+
+ "min_clip_height": 320,
+ "default_clip_height": 352,
+ "max_clip_height": 352,
+}
+
+# Logged-in Hugging Face users get the standard, calibrated experience
+CONFIG_FOR_STANDARD_HF_USERS = {
+ "max_rendering_time_per_client_per_video_in_sec": 15 * 60,
+
+ "min_num_inference_steps": 2,
+ "default_num_inference_steps": 4,
+ "max_num_inference_steps": 4,
+
+ "min_num_frames": 9, # 8 + 1
+ "default_num_frames": 81, # 8*10 + 1
+ "max_num_frames": 81,
+
+ "min_clip_duration_seconds": 1,
+ "default_clip_duration_seconds": 3,
+ "max_clip_duration_seconds": 3,
+
+ "min_clip_playback_speed": 0.7,
+ "default_clip_playback_speed": 0.7,
+ "max_clip_playback_speed": 0.7,
+
+ "min_clip_framerate": 8,
+ "default_clip_framerate": 25,
+ "max_clip_framerate": 25,
+
+ "min_clip_width": 544,
+ "default_clip_width": 1152, # 928, # 1216, # 768, # 640,
+ "max_clip_width": 1152, # 928, # 1216, # 768, # 640,
+
+ "min_clip_height": 320,
+ "default_clip_height": 640, # 512, # 448, # 416,
+ "max_clip_height": 640, # 512, # 448, # 416,
+}
+
+# Hugging Face users with a Pro account enjoy an improved experience
+CONFIG_FOR_PRO_HF_USERS = {
+ "max_rendering_time_per_client_per_video_in_sec": 20 * 60,
+
+ "min_num_inference_steps": 2,
+ "default_num_inference_steps": 4,
+ "max_num_inference_steps": 4,
+
+ "min_num_frames": 9, # 8 + 1
+ "default_num_frames": 81, # 8*10 + 1
+ "max_num_frames": 81,
+
+ "min_clip_duration_seconds": 1,
+ "default_clip_duration_seconds": 3,
+ "max_clip_duration_seconds": 3,
+
+ "min_clip_playback_speed": 0.7,
+ "default_clip_playback_speed": 0.7,
+ "max_clip_playback_speed": 0.7,
+
+ "min_clip_framerate": 8,
+ "default_clip_framerate": 25,
+ "max_clip_framerate": 25,
+
+ "min_clip_width": 544,
+ "default_clip_width": 1152, # 928, # 1216, # 768, # 640,
+ "max_clip_width": 1152, # 928, # 1216, # 768, # 640,
+
+ "min_clip_height": 320,
+ "default_clip_height": 640, # 512, # 448, # 416,
+ "max_clip_height": 640, # 512, # 448, # 416,
+}
+
+CONFIG_FOR_ADMIN_HF_USERS = {
+ "max_rendering_time_per_client_per_video_in_sec": 60 * 60,
+
+ "min_num_inference_steps": 2,
+ "default_num_inference_steps": 4,
+ "max_num_inference_steps": 4,
+
+ "min_num_frames": 9, # 8 + 1
+ "default_num_frames": 81, # (8 * 10) + 1
+ "max_num_frames": 129, # (8 * 16) + 1
+
+ "min_clip_duration_seconds": 1,
+ "default_clip_duration_seconds": 2,
+ "max_clip_duration_seconds": 4,
+
+ "min_clip_playback_speed": 0.7,
+ "default_clip_playback_speed": 0.7,
+ "max_clip_playback_speed": 1.0,
+
+ "min_clip_framerate": 8,
+ "default_clip_framerate": 30,
+ "max_clip_framerate": 60,
+
+ "min_clip_width": 544,
+ "default_clip_width": 1152, # 928, # 1216, # 768, # 640,
+ "max_clip_width": 1152, # 928, # 1216, # 768, # 640,
+
+ "min_clip_height": 320,
+ "default_clip_height": 640, # 512, # 448, # 416,
+ "max_clip_height": 640, # 512, # 448, # 416,
+}
+
+# For now, admins share the Pro limits (this intentionally overrides the dict defined above)
+CONFIG_FOR_ADMIN_HF_USERS = CONFIG_FOR_PRO_HF_USERS
\ No newline at end of file
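
These four dicts line up with the `UserRole` values ('anon', 'normal', 'pro', 'admin') that `validate_user_token` returns in api_core.py. A hypothetical lookup helper, for illustration only (`get_config_for_role` does not exist in the module):

```python
from api_config import (
    CONFIG_FOR_ADMIN_HF_USERS,
    CONFIG_FOR_ANONYMOUS_USERS,
    CONFIG_FOR_PRO_HF_USERS,
    CONFIG_FOR_STANDARD_HF_USERS,
)

# Hypothetical helper: map a validated user role to its limits dict.
_ROLE_CONFIGS = {
    "anon": CONFIG_FOR_ANONYMOUS_USERS,
    "normal": CONFIG_FOR_STANDARD_HF_USERS,
    "pro": CONFIG_FOR_PRO_HF_USERS,
    "admin": CONFIG_FOR_ADMIN_HF_USERS,  # currently aliased to the Pro config
}


def get_config_for_role(role: str) -> dict:
    """Unknown roles fall back to the strictest (anonymous) limits."""
    return _ROLE_CONFIGS.get(role, CONFIG_FOR_ANONYMOUS_USERS)


print(get_config_for_role("pro")["max_clip_framerate"])  # -> 25
```
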
diff --git a/reference_example/api_core.py b/reference_example/api_core.py
new file mode 100644
index 0000000000000000000000000000000000000000..8066d154d3d041cc08b3f231ff9b0378b9b5424a
--- /dev/null
+++ b/reference_example/api_core.py
@@ -0,0 +1,1068 @@
+import logging
+import os
+import io
+import re
+import base64
+import uuid
+from typing import Dict, Any, Optional, List, Literal
+from dataclasses import dataclass
+from asyncio import Lock, Queue
+import asyncio
+import time
+import datetime
+from contextlib import asynccontextmanager
+from collections import defaultdict
+from aiohttp import web, ClientSession
+from huggingface_hub import InferenceClient, HfApi
+from gradio_client import Client
+import random
+import yaml
+import json
+
+from api_config import *
+
+# User role type
+UserRole = Literal['anon', 'normal', 'pro', 'admin']
+
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+
+def generate_seed():
+ """Generate a random positive 32-bit integer seed."""
+ return random.randint(0, 2**32 - 1)
+
+def sanitize_yaml_response(response_text: str) -> str:
+ """
+ Sanitize and format AI response into valid YAML.
+ Returns properly formatted YAML string.
+ """
+
+ response_text = response_text.split("```")[0]
+
+ # Remove any markdown code block indicators and YAML document markers
+ clean_text = re.sub(r'```yaml|```|---|\.\.\.$', '', response_text.strip())
+
+ # Split into lines and process each line
+ lines = clean_text.split('\n')
+ sanitized_lines = []
+ current_field = None
+
+ for line in lines:
+ stripped = line.strip()
+ if not stripped:
+ continue
+
+ # Handle field starts
+ if stripped.startswith('title:') or stripped.startswith('description:'):
+ # Ensure proper YAML format with space after colon and proper quoting
+ field_name = stripped.split(':', 1)[0]
+ field_value = stripped.split(':', 1)[1].strip().strip('"\'')
+
+ # Quote the value if it contains special characters
+ if any(c in field_value for c in ':[]{},&*#?|-<>=!%@`'):
+ field_value = f'"{field_value}"'
+
+ sanitized_lines.append(f"{field_name}: {field_value}")
+ current_field = field_name
+
+ elif stripped.startswith('tags:'):
+ sanitized_lines.append('tags:')
+ current_field = 'tags'
+
+ elif stripped.startswith('-') and current_field == 'tags':
+ # Process tag values
+ tag = stripped[1:].strip().strip('"\'')
+ if tag:
+ # Clean and format tag
+ tag = re.sub(r'[^\x00-\x7F]+', '', tag) # Remove non-ASCII
+ tag = re.sub(r'[^a-zA-Z0-9\s-]', '', tag) # Keep only alphanumeric and hyphen
+ tag = tag.strip().lower().replace(' ', '-')
+ if tag:
+ sanitized_lines.append(f" - {tag}")
+
+ elif current_field in ['title', 'description']:
+ # Handle multi-line title/description continuation
+ value = stripped.strip('"\'')
+ if value:
+ # Append to previous line
+ prev = sanitized_lines[-1]
+ sanitized_lines[-1] = f"{prev} {value}"
+
+ # Ensure the YAML has all required fields
+ required_fields = {'title', 'description', 'tags'}
+ found_fields = {line.split(':')[0].strip() for line in sanitized_lines if ':' in line}
+
+ for field in required_fields - found_fields:
+ if field == 'tags':
+ sanitized_lines.extend(['tags:', ' - default'])
+ else:
+ sanitized_lines.append(f'{field}: "No {field} provided"')
+
+ return '\n'.join(sanitized_lines)
+
+@dataclass
+class Endpoint:
+ id: int
+ url: str
+ busy: bool = False
+ last_used: float = 0
+ error_count: int = 0
+ error_until: float = 0 # Timestamp until which this endpoint is considered in error state
+
+class EndpointManager:
+ def __init__(self):
+ self.endpoints: List[Endpoint] = []
+ self.lock = Lock()
+ self.initialize_endpoints()
+ self.last_used_index = -1 # Track the last used endpoint for round-robin
+
+ def initialize_endpoints(self):
+ """Initialize the list of endpoints"""
+ for i, url in enumerate(VIDEO_ROUND_ROBIN_ENDPOINT_URLS):
+ endpoint = Endpoint(id=i + 1, url=url)
+ self.endpoints.append(endpoint)
+
+ def _get_next_free_endpoint(self):
+ """Get the next available non-busy endpoint, or oldest endpoint if all are busy"""
+ current_time = time.time()
+
+ # First priority: Get any non-busy and non-error endpoint
+ free_endpoints = [
+ ep for ep in self.endpoints
+ if not ep.busy and current_time > ep.error_until
+ ]
+
+ if free_endpoints:
+ # Return the least recently used free endpoint
+ return min(free_endpoints, key=lambda ep: ep.last_used)
+
+ # Second priority: If all busy/error, use round-robin but skip error endpoints
+ tried_count = 0
+ next_index = self.last_used_index
+
+ while tried_count < len(self.endpoints):
+ next_index = (next_index + 1) % len(self.endpoints)
+ tried_count += 1
+
+ # If endpoint is not in error state, use it
+ if current_time > self.endpoints[next_index].error_until:
+ self.last_used_index = next_index
+ return self.endpoints[next_index]
+
+ # If all endpoints are in error state, use the one with earliest error expiry
+ self.last_used_index = next_index
+ return min(self.endpoints, key=lambda ep: ep.error_until)
+
+ @asynccontextmanager
+ async def get_endpoint(self, max_wait_time: int = 10):
+ """Get the next available endpoint using a context manager"""
+ start_time = time.time()
+ endpoint = None
+
+ try:
+ while True:
+ if time.time() - start_time > max_wait_time:
+ raise TimeoutError(f"Could not acquire an endpoint within {max_wait_time} seconds")
+
+ async with self.lock:
+ # Get the next available endpoint using our selection strategy
+ endpoint = self._get_next_free_endpoint()
+
+ # Mark it as busy
+ endpoint.busy = True
+ endpoint.last_used = time.time()
+ #logger.info(f"Using endpoint {endpoint.id} (busy: {endpoint.busy}, last used: {endpoint.last_used})")
+ break
+
+ yield endpoint
+
+ finally:
+ if endpoint:
+ async with self.lock:
+ endpoint.busy = False
+ endpoint.last_used = time.time()
+ # We don't need to put back into queue - our strategy now picks directly from the list
+
+class ChatRoom:
+ def __init__(self):
+ self.messages = []
+ self.connected_clients = set()
+ self.max_history = 100
+
+ def add_message(self, message):
+ self.messages.append(message)
+ if len(self.messages) > self.max_history:
+ self.messages.pop(0)
+
+ def get_recent_messages(self, limit=50):
+ return self.messages[-limit:]
+
+class VideoGenerationAPI:
+ def __init__(self):
+ self.inference_client = InferenceClient(token=HF_TOKEN)
+ self.hf_api = HfApi(token=HF_TOKEN)
+ self.endpoint_manager = EndpointManager()
+ self.active_requests: Dict[str, asyncio.Future] = {}
+ self.chat_rooms = defaultdict(ChatRoom)
+ self.video_events: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
+ self.event_history_limit = 50
+ # Cache for user roles to avoid repeated API calls
+ self.user_role_cache: Dict[str, Dict[str, Any]] = {}
+ # Cache expiration time (10 minutes)
+ self.cache_expiration = 600
+
+
+ def _add_event(self, video_id: str, event: Dict[str, Any]):
+ """Add an event to the video's history and maintain the size limit"""
+ events = self.video_events[video_id]
+ events.append(event)
+ if len(events) > self.event_history_limit:
+ events.pop(0)
+
+ async def validate_user_token(self, token: str) -> UserRole:
+ """
+ Validates a Hugging Face token and determines the user's role.
+
+ Returns one of:
+ - 'anon': Anonymous user (no token or invalid token)
+ - 'normal': Standard Hugging Face user
+ - 'pro': Hugging Face Pro user
+ - 'admin': Admin user (username in ADMIN_ACCOUNTS)
+ """
+ # If no token is provided, the user is anonymous
+ if not token:
+ return 'anon'
+
+ # Check if we have a cached result for this token
+ current_time = time.time()
+ if token in self.user_role_cache:
+ cached_data = self.user_role_cache[token]
+ # If the cache is still valid
+ if current_time - cached_data['timestamp'] < self.cache_expiration:
+ logger.info(f"Using cached user role: {cached_data['role']}")
+ return cached_data['role']
+
+ # No valid cache, need to check the token with the HF API
+ try:
+ # Use HF API to validate the token and get user info
+ logger.info("Validating Hugging Face token...")
+
+ # Run in executor to avoid blocking the event loop
+ user_info = await asyncio.get_event_loop().run_in_executor(
+ None,
+ lambda: self.hf_api.whoami(token=token)
+ )
+
+ # Handle both object and dict response formats from whoami
+ username = user_info.get('name') if isinstance(user_info, dict) else getattr(user_info, 'name', None)
+ is_pro = user_info.get('is_pro') if isinstance(user_info, dict) else getattr(user_info, 'is_pro', False)
+
+ if not username:
+ logger.error(f"Could not determine username from user_info: {user_info}")
+ return 'anon'
+
+ logger.info(f"Token valid for user: {username}")
+
+ # Determine the user role based on the information
+ user_role: UserRole
+
+ # Check if the user is an admin
+ if username in ADMIN_ACCOUNTS:
+ user_role = 'admin'
+ # Check if the user has a pro account
+ elif is_pro:
+ user_role = 'pro'
+ else:
+ user_role = 'normal'
+
+ # Cache the result
+ self.user_role_cache[token] = {
+ 'role': user_role,
+ 'timestamp': current_time,
+ 'username': username
+ }
+
+ return user_role
+
+ except Exception as e:
+ logger.error(f"Failed to validate Hugging Face token: {str(e)}")
+ # If validation fails, the user is treated as anonymous
+ return 'anon'
+
+ async def download_video(self, url: str) -> bytes:
+ """Download video file from URL and return bytes"""
+ async with ClientSession() as session:
+ async with session.get(url) as response:
+ if response.status != 200:
+ raise Exception(f"Failed to download video: HTTP {response.status}")
+ return await response.read()
+
+ async def search_video(self, query: str, attempt_count: int = 0) -> Optional[dict]:
+ """Generate a single search result using HF text generation"""
+ # Maximum number of attempts to generate a description without placeholder tags
+ max_attempts = 2
+ current_attempt = attempt_count
+ # Use a random temperature between 0.68 and 0.72 to generate more diverse results
+ # and prevent duplicate results from successive calls with the same prompt
+ temperature = random.uniform(0.68, 0.72)
+
+ while current_attempt <= max_attempts:
+ prompt = f"""# Instruction
+Your response MUST be a YAML object containing a title and description, consistent with what we can find on a video sharing platform.
+Format your YAML response with only those fields: "title" (a short string) and "description" (string caption of the scene). Do not add any other field.
+In the description field, describe in a very synthetic way the visuals of the first shot (first scene), eg "