diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..3152023feaef0811cf55ea93ff7d32c71a004d03 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,45 @@ +FROM nvidia/cuda:12.0.1-runtime-ubuntu22.04 + +# Set environment variables +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 +ENV DEBIAN_FRONTEND=noninteractive + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-dev \ + ffmpeg \ + libsm6 \ + libxext6 \ + libxrender-dev \ + libglib2.0-0 \ + git \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# Set working directory +WORKDIR /app + +# Copy requirements first to leverage Docker caching +COPY requirements.txt . + +# Install Python dependencies +RUN pip3 install --no-cache-dir -r requirements.txt +RUN pip3 install aiohttp + +# Install additional required packages +RUN pip3 install --no-cache-dir torch torchvision torchaudio + +# Copy application code +COPY . . + +# Create assets directory if it doesn't exist +RUN mkdir -p /app/assets + +# Expose the port used by the server +EXPOSE 8080 + +# Set entry command +CMD ["python3", "server.py", "--host", "0.0.0.0", "--port", "8080"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..08357e62af5bcd3f43c8028206d75238c4db1655 --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2024 Eloi Alonso +Copyright (c) 2025 Enigma Labs AI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 4a1d97e427779dab9d5035c290b67dfaccc18b6d..212b895a1cb25e646708ed05f0bab4bf57d49284 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,170 @@ --- -title: Tikslop Gaming Multiverse -emoji: 🏃 -colorFrom: purple -colorTo: purple +title: Multiverse +emoji: 🐟 +colorFrom: blue +colorTo: blue sdk: docker -pinned: false +app_file: server.py +pinned: true +short_description: AI Multiplayer World Model +app_port: 8080 +disable_embedding: false --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Multiverse: The First AI Multiplayer World Model + +🌐 [Enigma-AI website](https://enigma-labs.io/) - 📚 [Technical Blog](https://enigma-labs.io/) - [🤗 Model on Huggingface](https://huggingface.co/Enigma-AI/multiverse) - [🤗 Datasets on Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res) - 𝕏 [Multiverse Tweet](https://x.com/j0nathanj/status/1920516649511244258) + +
+  <p align="center"><em>Two human players driving cars in Multiverse</em></p>
+  <p align="center"><em>Cars in Multiverse</em></p>
+
+---
+
+## Installation
+```bash
+git clone https://github.com/EnigmaLabsAI/multiverse
+cd multiverse
+pip install -r requirements.txt
+```
+
+### Running the model
+
+```bash
+python src/play.py --compile
+```
+
+> Note: on Apple Silicon you must enable CPU fallback for the MPS backend: `PYTORCH_ENABLE_MPS_FALLBACK=1 python src/play.py`
+
+When running this command, you will be prompted with the controls. Press `enter` to start:
+![img.png](assets/img.png)
+
+Then the game will start:
+* To control the silver car at the top of the screen, use the arrow keys.
+* To control the blue car at the bottom, use the WASD keys.
+
+![img_2.png](assets/img_2.png)
+
+---
+
+
+## Training
+
+Multiverse comprises two models:
+* Denoiser - a world model that simulates a game
+* Upsampler - a model that takes the frames from the denoiser and increases their resolution
+
+### Denoiser training
+
+#### 1. Download the dataset
+Download the Denoiser's training set from [🤗 Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res).
+
+#### 2. Process data for training
+Run the command:
+```bash
+python src/process_denoiser_files.py
+```
+
+#### 3. Edit training configuration
+
+Edit [config/env/racing.yaml](config/env/racing.yaml) and set:
+- `path_data_low_res` to `<dataset-path>/low_res`
+- `path_data_full_res` to `<dataset-path>/full_res`
+
+Edit [config/trainer.yaml](config/trainer.yaml) to train the `denoiser`:
+```yaml
+train_model: denoiser
+```
+
+#### 4. Launch training run
+
+You can then launch a training run with `python src/main.py`.
+
+
+### Upsampler training
+
+#### 1. Download the dataset
+Download the Upsampler's training set from [🤗 Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-full-res).
+
+#### 2. Process data for training
+Run the command:
+```bash
+python src/process_upsampler_files.py
+```
+
+#### 3. Edit training configuration
+
+Edit [config/env/racing.yaml](config/env/racing.yaml) and set:
+- `path_data_low_res` to `<dataset-path>/low_res`
+- `path_data_full_res` to `<dataset-path>/full_res`
+
+Edit [config/trainer.yaml](config/trainer.yaml) to train the `upsampler`:
+```yaml
+train_model: upsampler
+```
+
+#### 4. Launch training run
+
+You can then launch a training run with `python src/main.py`.
+
+
+---
+
+## Datasets
+
+1. We've collected over 4 hours of multiplayer (1v1) footage from Gran Turismo 4 at a resolution of 48x64 (per player): [🤗 Huggingface link](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res).
+
+2. A sparse sampling of full-resolution, cropped frames is available for training the upsampler at a resolution of 350x530: [🤗 Huggingface link](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-full-res).
+
+The datasets contain a variety of situations: acceleration, braking, overtakes, crashes, and expert driving for both players.
+You can read about the data collection mechanism [here](https://enigma-labs.io/blog).
+
+Note: The full-resolution dataset is only for upsampler training and is not suitable for world model training.
+
+---
+
+## Outside resources
+
+- DIAMOND - https://github.com/eloialonso/diamond
+- AI-MarioKart64 - https://github.com/Dere-Wah/AI-MarioKart64
+
+---
+
+## Cloud Gaming Server
+
+This project now includes a WebSocket-based cloud gaming server that allows you to play the game through a web browser.
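+
+For illustration, here is a minimal Python sketch of the client protocol, using `aiohttp` (already in the requirements). It assumes a server running locally on port 8080; the `requestId` values and output filenames are arbitrary. The action names (`start_stream`, `keyboard_input`, `frame`, `stop_stream`) mirror those handled by `example/server.py` and `example/client.js`:
+
+```python
+import asyncio, base64, json
+import aiohttp
+
+async def main():
+    async with aiohttp.ClientSession() as session:
+        async with session.ws_connect("ws://localhost:8080/ws") as ws:
+            # The server greets each connection with a welcome message listing scenes.
+            welcome = json.loads((await ws.receive()).data)
+            print("Scenes:", welcome.get("scenes"))
+
+            # Ask the server to stream frames, then hold "forward" down.
+            await ws.send_json({"action": "start_stream", "requestId": "r1", "fps": 16})
+            await ws.send_json({"action": "keyboard_input", "requestId": "r2",
+                                "key": "forward", "pressed": True})
+
+            saved = 0
+            async for msg in ws:
+                data = json.loads(msg.data)
+                if data.get("action") == "frame":
+                    # Frames arrive as base64-encoded JPEG bytes.
+                    with open(f"frame_{saved:03d}.jpg", "wb") as f:
+                        f.write(base64.b64decode(data["frameData"]))
+                    saved += 1
+                    if saved == 10:
+                        break
+
+            await ws.send_json({"action": "stop_stream", "requestId": "r3"})
+
+asyncio.run(main())
+```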
+ +### Using Docker (Recommended for GPU Servers) + +The easiest way to deploy the cloud gaming server on a machine with an NVIDIA GPU is using Docker: + +```bash +# Build the Docker image +docker build -t ai-game-multiverse . + +# Run the container with GPU support +docker run --gpus all -p 8080:8080 ai-game-multiverse +``` + +Then access the web interface at http://yourserver:8080 + +### Features + +- Web-based interface accessible from any modern browser +- Real-time streaming of AI-generated game frames +- Keyboard and mouse controls +- Multiple scene selection +- WebSocket communication for low-latency interaction + +### Usage + +1. Access the web interface at http://yourserver:8080 +2. Click "Connect" to establish a WebSocket connection +3. Select a scene from the dropdown +4. Click "Start Stream" to begin streaming frames +5. Use WASD keys for movement, Space for jump, Shift for attack +6. Mouse controls camera view (click on game area to capture mouse) + +Note: The server requires an NVIDIA GPU for optimal performance with the AI models. Without a suitable GPU, it will fall back to using simple placeholder frames. \ No newline at end of file diff --git a/config/agent/racing.yaml b/config/agent/racing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..470a4f2768357c5591005fb8e7b26ac2f8b0a14a --- /dev/null +++ b/config/agent/racing.yaml @@ -0,0 +1,66 @@ +_target_: agent.AgentConfig + +denoiser: + _target_: models.diffusion.DenoiserConfig + sigma_data: 0.5 + sigma_offset_noise: 0.1 + noise_previous_obs: true + upsampling_factor: null + frame_sampling: + - count: 4 + stride: 1 + - count: 4 + stride: 4 + inner_model: + _target_: models.diffusion.InnerModelConfig + img_channels: 6 + num_steps_conditioning: 8 + cond_channels: 2048 + depths: + - 2 + - 2 + - 2 + - 2 + channels: + - 128 + - 256 + - 512 + - 1024 + attn_depths: + - 0 + - 0 + - 1 + - 1 + +upsampler: + _target_: models.diffusion.DenoiserConfig + sigma_data: 0.5 + sigma_offset_noise: 0.1 + noise_previous_obs: false + upsampling_factor: 10 + upsampling_frame_height: 350 + upsampling_frame_width: 530 + inner_model: + _target_: models.diffusion.InnerModelConfig + img_channels: 6 + num_steps_conditioning: 0 + cond_channels: 2048 + depths: + - 2 + - 2 + - 2 + - 2 + channels: + - 64 + - 64 + - 128 + - 256 + attn_depths: + - 0 + - 0 + - 0 + - 0 + +rew_end_model: null + +actor_critic: null diff --git a/config/env/racing.yaml b/config/env/racing.yaml new file mode 100644 index 0000000000000000000000000000000000000000..79c035cfe6d02d3dec551bea230fcdcf5f75f842 --- /dev/null +++ b/config/env/racing.yaml @@ -0,0 +1,7 @@ +train: + id: racing + size: [700, 530] +num_actions: 66 +path_data_low_res: null +path_data_full_res: null +keymap: racing diff --git a/config/trainer.yaml b/config/trainer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4217362d937e42b6144e4a624bdd2c5e3a5a125f --- /dev/null +++ b/config/trainer.yaml @@ -0,0 +1,113 @@ +defaults: + - _self_ + - env: racing + - agent: racing + - world_model_env: fast + +hydra: + job: + chdir: True + +wandb: + mode: offline + project: null + entity: null + name: null + group: null + tags: null + +initialization: + path_to_ckpt: null + load_denoiser: True + load_rew_end_model: True + load_actor_critic: True + +common: + devices: all # int, list of int, cpu, or all + seed: null + resume: False # do not modify, set by scripts/resume.sh only. 
+ +checkpointing: + save_agent_every: 5 + num_to_keep: 11 # number of checkpoints to keep, use null to disable + +collection: + train: + num_envs: 1 + epsilon: 0.01 + num_steps_total: 100000 + first_epoch: + min: 5000 + max: 10000 # null: no maximum + threshold_rew: 10 + steps_per_epoch: 100 + test: + num_envs: 1 + num_episodes: 4 + epsilon: 0.0 + num_final_episodes: 100 + +static_dataset: + path: ${env.path_data_low_res} + ignore_sample_weights: True + +training: + should: True + num_final_epochs: 600 + cache_in_ram: False + num_workers_data_loaders: 1 + model_free: False # if True, turn off world_model training and RL in imagination + compile_wm: False + +evaluation: + should: True + every: 20 + +train_model: denoiser + +denoiser: + training: + num_autoregressive_steps: 8 + initial_num_consecutive_page_count: 1 + num_consecutive_pages: + - epoch: 400 + count: 10 + - epoch: 500 + count: 50 + start_after_epochs: 0 + steps_first_epoch: 10 + steps_per_epoch: 20 + sample_weights: null + batch_size: 30 + grad_acc_steps: 2 + lr_warmup_steps: 100 + max_grad_norm: 10.0 + + optimizer: + lr: 1e-4 + weight_decay: 1e-2 + eps: 1e-8 + + sigma_distribution: # log normal distribution for sigma during training + _target_: models.diffusion.SigmaDistributionConfig + loc: -1.2 + scale: 1.2 + sigma_min: 2e-3 + sigma_max: 20 + +upsampler: + training: + num_autoregressive_steps: 1 + initial_num_consecutive_page_count: 1 + start_after_epochs: 0 + steps_first_epoch: 20 + steps_per_epoch: 20 + sample_weights: null + batch_size: 4 + grad_acc_steps: 2 + lr_warmup_steps: 100 + max_grad_norm: 10.0 + + optimizer: ${denoiser.optimizer} + sigma_distribution: ${denoiser.sigma_distribution} + diff --git a/config/world_model_env/fast.yaml b/config/world_model_env/fast.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08eadc86c1cbe70f1ab822844f1194c9319d8120 --- /dev/null +++ b/config/world_model_env/fast.yaml @@ -0,0 +1,27 @@ +_target_: envs.WorldModelEnvConfig +horizon: 1000 +num_batches_to_preload: 256 +diffusion_sampler_next_obs: + _target_: models.diffusion.DiffusionSamplerConfig + num_steps_denoising: 1 + sigma_min: 2e-3 + sigma_max: 5.0 + rho: 7 + order: 1 # 1: Euler, 2: Heun + s_churn: 0.0 # Amount of stochasticity + s_tmin: 0.0 + s_tmax: ${eval:'float("inf")'} + s_noise: 1.0 + s_cond: 0.005 +diffusion_sampler_upsampling: + _target_: models.diffusion.DiffusionSamplerConfig + num_steps_denoising: 1 + sigma_min: 1 + sigma_max: 5.0 + rho: 7 + order: 2 # 1: Euler, 2: Heun + s_churn: 10.0 # Amount of stochasticity + s_tmin: 1 + s_tmax: 5 + s_noise: 0.9 + s_cond: 0 \ No newline at end of file diff --git a/example/Dockerfile b/example/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..28542909b9f5203a3d24ebe0234abc82965b4486 --- /dev/null +++ b/example/Dockerfile @@ -0,0 +1,59 @@ +FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 + +ARG DEBIAN_FRONTEND=noninteractive + +ENV PYTHONUNBUFFERED=1 + +RUN apt-get update && apt-get install --no-install-recommends -y \ + build-essential \ + python3.11 \ + python3-pip \ + python3-dev \ + git \ + curl \ + ffmpeg \ + libglib2.0-0 \ + libsm6 \ + libxrender1 \ + libxext6 \ + ninja-build \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +WORKDIR /code + +COPY ./requirements.txt /code/requirements.txt + +# Set up a new user named "user" with user ID 1000 +RUN useradd -m -u 1000 user +# Switch to the "user" user +USER user +# Set home to the user's home directory +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +# Set Python path 
and environment variables +ENV PYTHONPATH=$HOME/app \ + PYTHONUNBUFFERED=1 \ + DATA_ROOT=/tmp/data + +RUN echo "Installing requirements.txt" +RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt + +# Install NVIDIA Apex with CUDA and C++ extensions +RUN cd $HOME && \ + git clone https://github.com/NVIDIA/apex && \ + cd apex && \ + NVCC_APPEND_FLAGS="--threads 4" pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--parallel" --global-option="8" ./ + +WORKDIR $HOME/app + +# Copy all files and set proper ownership +COPY --chown=user . $HOME/app + +# Expose the port that server.py uses (8080) +EXPOSE 8080 + +ENV PORT 8080 + +# Run the HF space launcher script which sets up the correct paths +CMD ["python3", "run_hf_space.py"] \ No newline at end of file diff --git a/example/client.js b/example/client.js new file mode 100644 index 0000000000000000000000000000000000000000..36724958f7e64964cc73355a3d7908abb7a2abe0 --- /dev/null +++ b/example/client.js @@ -0,0 +1,603 @@ +// MatrixGame WebSocket Client + +// WebSocket connection +let socket = null; +let userId = null; +let isStreaming = false; +let lastFrameTime = 0; +let frameCount = 0; +let fpsUpdateInterval = null; + +// DOM Elements +const connectBtn = document.getElementById('connect-btn'); +const startStreamBtn = document.getElementById('start-stream-btn'); +const stopStreamBtn = document.getElementById('stop-stream-btn'); +const sceneSelect = document.getElementById('scene-select'); +const gameCanvas = document.getElementById('game-canvas'); +const connectionLog = document.getElementById('connection-log'); +const mousePosition = document.getElementById('mouse-position'); +const fpsCounter = document.getElementById('fps-counter'); +const mouseTrackingArea = document.getElementById('mouse-tracking-area'); + +// Pointer Lock API support check +const pointerLockSupported = 'pointerLockElement' in document || + 'mozPointerLockElement' in document || + 'webkitPointerLockElement' in document; + +// Keyboard DOM elements +const keyElements = { + 'w': document.getElementById('key-w'), + 'a': document.getElementById('key-a'), + 's': document.getElementById('key-s'), + 'd': document.getElementById('key-d'), + 'space': document.getElementById('key-space'), + 'shift': document.getElementById('key-shift') +}; + +// Key mapping to action names +const keyToAction = { + 'w': 'forward', + 'arrowup': 'forward', + 'a': 'left', + 'arrowleft': 'left', + 's': 'back', + 'arrowdown': 'back', + 'd': 'right', + 'arrowright': 'right', + ' ': 'jump', + 'shift': 'attack' +}; + +// Key state tracking +const keyState = { + 'forward': false, + 'back': false, + 'left': false, + 'right': false, + 'jump': false, + 'attack': false +}; + +// Mouse state +const mouseState = { + x: 0, + y: 0, + captured: false +}; + +// Test server connectivity before establishing WebSocket +async function testServerConnectivity() { + try { + // Get base path by extracting path from the script tag's src attribute + let basePath = ''; + const scriptTags = document.getElementsByTagName('script'); + for (const script of scriptTags) { + if (script.src.includes('client.js')) { + const url = new URL(script.src); + basePath = url.pathname.replace('/assets/client.js', ''); + break; + } + } + + // Try to fetch the debug endpoint to see if the server is accessible + const response = await fetch(`${window.location.protocol}//${window.location.host}${basePath}/api/debug`); + if (!response.ok) { 
+ throw new Error(`Server returned ${response.status}`); + } + + const debugInfo = await response.json(); + logMessage(`Server connection test successful! Server time: ${new Date(debugInfo.server_time * 1000).toLocaleTimeString()}`); + + // Log available routes from server + if (debugInfo.all_routes && debugInfo.all_routes.length > 0) { + logMessage(`Available routes: ${debugInfo.all_routes.join(', ')}`); + } + + // Return the debug info for connection setup + return debugInfo; + } catch (error) { + logMessage(`Server connection test failed: ${error.message}`); + return null; + } +} + +// Connect to WebSocket server +async function connectWebSocket() { + // First test connectivity to the server + logMessage('Testing server connectivity...'); + const debugInfo = await testServerConnectivity(); + + // Use secure WebSocket (wss://) if the page is loaded over HTTPS + const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + + // Get base path by extracting path from the script tag's src attribute + let basePath = ''; + if (debugInfo && debugInfo.base_path) { + // Use base path from server if available + basePath = debugInfo.base_path; + logMessage(`Using server-provided base path: ${basePath}`); + } else { + const scriptTags = document.getElementsByTagName('script'); + for (const script of scriptTags) { + if (script.src.includes('client.js')) { + const url = new URL(script.src); + basePath = url.pathname.replace('/assets/client.js', ''); + break; + } + } + } + + // Try both with and without base path for WebSocket connection + let serverUrl = `${protocol}//${window.location.hostname}${window.location.port ? ':' + window.location.port : ''}${basePath}/ws`; + logMessage(`Attempting to connect to WebSocket at ${serverUrl}...`); + + // For Hugging Face Spaces, try the direct /ws path if the base path doesn't work + const fallbackUrl = `${protocol}//${window.location.hostname}${window.location.port ? ':' + window.location.port : ''}/ws`; + + try { + socket = new WebSocket(serverUrl); + setupWebSocketHandlers(); + + // Set a timeout to try the fallback URL if the first one doesn't connect + setTimeout(() => { + if (socket.readyState !== WebSocket.OPEN && socket.readyState !== WebSocket.CONNECTING) { + logMessage(`Connection to ${serverUrl} failed. 
Trying fallback URL: ${fallbackUrl}`); + socket = new WebSocket(fallbackUrl); + setupWebSocketHandlers(); + } + }, 3000); + } catch (error) { + logMessage(`Error connecting to WebSocket: ${error.message}`); + resetUI(); + } +} + +// Set up WebSocket event handlers +function setupWebSocketHandlers() { + socket.onopen = () => { + logMessage('WebSocket connection established'); + connectBtn.textContent = 'Disconnect'; + startStreamBtn.disabled = false; + sceneSelect.disabled = false; + }; + + socket.onmessage = (event) => { + const message = JSON.parse(event.data); + + switch (message.action) { + case 'welcome': + userId = message.userId; + logMessage(`Connected with user ID: ${userId}`); + + // Update scene options if server provides them + if (message.scenes && Array.isArray(message.scenes)) { + sceneSelect.innerHTML = ''; + message.scenes.forEach(scene => { + const option = document.createElement('option'); + option.value = scene; + option.textContent = scene.charAt(0).toUpperCase() + scene.slice(1); + sceneSelect.appendChild(option); + }); + } + break; + + case 'frame': + // Process incoming frame + processFrame(message); + break; + + case 'start_stream': + if (message.success) { + isStreaming = true; + startStreamBtn.disabled = true; + stopStreamBtn.disabled = false; + logMessage(`Streaming started: ${message.message}`); + + // Start FPS counter + startFpsCounter(); + } else { + logMessage(`Error starting stream: ${message.error}`); + } + break; + + case 'stop_stream': + if (message.success) { + isStreaming = false; + startStreamBtn.disabled = false; + stopStreamBtn.disabled = true; + logMessage('Streaming stopped'); + + // Stop FPS counter + stopFpsCounter(); + } else { + logMessage(`Error stopping stream: ${message.error}`); + } + break; + + case 'pong': + // Server responded to ping + break; + + case 'change_scene': + if (message.success) { + logMessage(`Scene changed to ${message.scene}`); + } else { + logMessage(`Error changing scene: ${message.error}`); + } + break; + + default: + logMessage(`Received message: ${JSON.stringify(message)}`); + } + }; + + socket.onclose = (event) => { + logMessage(`WebSocket connection closed (code: ${event.code}, reason: ${event.reason || 'none given'})`); + resetUI(); + }; + + socket.onerror = (error) => { + logMessage(`WebSocket error. 
This is often caused by CORS issues or the server being inaccessible.`); + console.error('WebSocket error:', error); + resetUI(); + }; +} + +// Disconnect from WebSocket server +function disconnectWebSocket() { + if (socket && socket.readyState === WebSocket.OPEN) { + // Stop streaming if active + if (isStreaming) { + sendStopStream(); + } + + // Close the socket + socket.close(); + logMessage('Disconnected from server'); + } +} + +// Start streaming frames +function sendStartStream() { + if (socket && socket.readyState === WebSocket.OPEN) { + socket.send(JSON.stringify({ + action: 'start_stream', + requestId: generateRequestId(), + fps: 16 // Default FPS + })); + } +} + +// Stop streaming frames +function sendStopStream() { + if (socket && socket.readyState === WebSocket.OPEN) { + socket.send(JSON.stringify({ + action: 'stop_stream', + requestId: generateRequestId() + })); + } +} + +// Send keyboard input to server +function sendKeyboardInput(key, pressed) { + if (socket && socket.readyState === WebSocket.OPEN) { + socket.send(JSON.stringify({ + action: 'keyboard_input', + requestId: generateRequestId(), + key: key, + pressed: pressed + })); + } +} + +// Send mouse input to server +function sendMouseInput(x, y) { + if (socket && socket.readyState === WebSocket.OPEN && isStreaming) { + socket.send(JSON.stringify({ + action: 'mouse_input', + requestId: generateRequestId(), + x: x, + y: y + })); + } +} + +// Change scene +function sendChangeScene(scene) { + if (socket && socket.readyState === WebSocket.OPEN) { + socket.send(JSON.stringify({ + action: 'change_scene', + requestId: generateRequestId(), + scene: scene + })); + } +} + +// Process incoming frame +function processFrame(message) { + // Update FPS calculation + const now = performance.now(); + if (lastFrameTime > 0) { + frameCount++; + } + lastFrameTime = now; + + // Update the canvas with the new frame + if (message.frameData) { + gameCanvas.src = `data:image/jpeg;base64,${message.frameData}`; + } +} + +// Generate a random request ID +function generateRequestId() { + return Math.random().toString(36).substring(2, 15); +} + +// Log message to the connection info panel +function logMessage(message) { + const logEntry = document.createElement('div'); + logEntry.className = 'log-entry'; + + const timestamp = new Date().toLocaleTimeString(); + logEntry.textContent = `[${timestamp}] ${message}`; + + connectionLog.appendChild(logEntry); + connectionLog.scrollTop = connectionLog.scrollHeight; + + // Limit number of log entries + while (connectionLog.children.length > 100) { + connectionLog.removeChild(connectionLog.firstChild); + } +} + +// Start FPS counter updates +function startFpsCounter() { + frameCount = 0; + lastFrameTime = 0; + + // Update FPS display every second + fpsUpdateInterval = setInterval(() => { + fpsCounter.textContent = `FPS: ${frameCount}`; + frameCount = 0; + }, 1000); +} + +// Stop FPS counter updates +function stopFpsCounter() { + if (fpsUpdateInterval) { + clearInterval(fpsUpdateInterval); + fpsUpdateInterval = null; + } + fpsCounter.textContent = 'FPS: 0'; +} + +// Reset UI to initial state +function resetUI() { + connectBtn.textContent = 'Connect'; + startStreamBtn.disabled = true; + stopStreamBtn.disabled = true; + sceneSelect.disabled = true; + + // Reset key indicators + for (const key in keyElements) { + keyElements[key].classList.remove('active'); + } + + // Stop FPS counter + stopFpsCounter(); + + // Reset streaming state + isStreaming = false; +} + +// Event Listeners +connectBtn.addEventListener('click', 
() => { + if (socket && socket.readyState === WebSocket.OPEN) { + disconnectWebSocket(); + } else { + connectWebSocket(); + } +}); + +startStreamBtn.addEventListener('click', sendStartStream); +stopStreamBtn.addEventListener('click', sendStopStream); + +sceneSelect.addEventListener('change', () => { + sendChangeScene(sceneSelect.value); +}); + +// Keyboard event listeners +document.addEventListener('keydown', (event) => { + const key = event.key.toLowerCase(); + + // Map key to action + let action = keyToAction[key]; + if (!action && key === ' ') { + action = keyToAction[' ']; // Handle spacebar + } + + if (action && !keyState[action]) { + keyState[action] = true; + + // Update visual indicator + const keyElement = keyElements[key] || + (key === ' ' ? keyElements['space'] : null) || + (key === 'shift' ? keyElements['shift'] : null); + + if (keyElement) { + keyElement.classList.add('active'); + } + + // Send to server + sendKeyboardInput(action, true); + } + + // Prevent default actions for game controls + if (Object.keys(keyToAction).includes(key) || key === ' ') { + event.preventDefault(); + } +}); + +document.addEventListener('keyup', (event) => { + const key = event.key.toLowerCase(); + + // Map key to action + let action = keyToAction[key]; + if (!action && key === ' ') { + action = keyToAction[' ']; // Handle spacebar + } + + if (action && keyState[action]) { + keyState[action] = false; + + // Update visual indicator + const keyElement = keyElements[key] || + (key === ' ' ? keyElements['space'] : null) || + (key === 'shift' ? keyElements['shift'] : null); + + if (keyElement) { + keyElement.classList.remove('active'); + } + + // Send to server + sendKeyboardInput(action, false); + } +}); + +// Mouse capture functions +function requestPointerLock() { + if (!mouseState.captured && pointerLockSupported) { + mouseTrackingArea.requestPointerLock = mouseTrackingArea.requestPointerLock || + mouseTrackingArea.mozRequestPointerLock || + mouseTrackingArea.webkitRequestPointerLock; + mouseTrackingArea.requestPointerLock(); + logMessage('Mouse captured. 
Press ESC to release.'); + } +} + +function exitPointerLock() { + if (mouseState.captured) { + document.exitPointerLock = document.exitPointerLock || + document.mozExitPointerLock || + document.webkitExitPointerLock; + document.exitPointerLock(); + logMessage('Mouse released.'); + } +} + +// Handle pointer lock change events +document.addEventListener('pointerlockchange', pointerLockChangeHandler); +document.addEventListener('mozpointerlockchange', pointerLockChangeHandler); +document.addEventListener('webkitpointerlockchange', pointerLockChangeHandler); + +function pointerLockChangeHandler() { + if (document.pointerLockElement === mouseTrackingArea || + document.mozPointerLockElement === mouseTrackingArea || + document.webkitPointerLockElement === mouseTrackingArea) { + // Pointer is locked, enable mouse movement tracking + mouseState.captured = true; + document.addEventListener('mousemove', handleMouseMovement); + } else { + // Pointer is unlocked, disable mouse movement tracking + mouseState.captured = false; + document.removeEventListener('mousemove', handleMouseMovement); + // Reset mouse state + mouseState.x = 0; + mouseState.y = 0; + mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`; + throttledSendMouseInput(); + } +} + +// Mouse tracking with pointer lock +function handleMouseMovement(event) { + if (mouseState.captured) { + // Use movement for mouse look when captured + const sensitivity = 0.005; // Adjust sensitivity + mouseState.x += event.movementX * sensitivity; + mouseState.y -= event.movementY * sensitivity; // Invert Y for intuitive camera control + + // Clamp values + mouseState.x = Math.max(-1, Math.min(1, mouseState.x)); + mouseState.y = Math.max(-1, Math.min(1, mouseState.y)); + + // Update display + mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`; + + // Send to server (throttled) + throttledSendMouseInput(); + } +} + +// Mouse click to capture +mouseTrackingArea.addEventListener('click', () => { + if (!mouseState.captured && isStreaming) { + requestPointerLock(); + } +}); + +// Standard mouse tracking for when pointer is not locked +mouseTrackingArea.addEventListener('mousemove', (event) => { + if (!mouseState.captured) { + // Calculate normalized coordinates relative to the center of the tracking area + const rect = mouseTrackingArea.getBoundingClientRect(); + const centerX = rect.width / 2; + const centerY = rect.height / 2; + + // Calculate relative position from center (-1 to 1) + const relX = (event.clientX - rect.left - centerX) / centerX; + const relY = (event.clientY - rect.top - centerY) / centerY; + + // Scale down for smoother movement (similar to conditions.py) + const scaleFactor = 0.05; + mouseState.x = relX * scaleFactor; + mouseState.y = -relY * scaleFactor; // Invert Y for intuitive camera control + + // Update display + mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`; + + // Send to server (throttled) + throttledSendMouseInput(); + } +}); + +// Throttle mouse movement to avoid flooding the server +const throttledSendMouseInput = (() => { + let lastSentTime = 0; + const interval = 50; // milliseconds + + return () => { + const now = performance.now(); + if (now - lastSentTime >= interval) { + sendMouseInput(mouseState.x, mouseState.y); + lastSentTime = now; + } + }; +})(); + +// Toggle panel collapse/expand +function togglePanel(panelId) { + const panel = document.getElementById(panelId); + const button = 
panel.querySelector('.toggle-button'); + + if (panel.classList.contains('collapsed')) { + // Expand the panel + panel.classList.remove('collapsed'); + button.textContent = '−'; // Minus sign + } else { + // Collapse the panel + panel.classList.add('collapsed'); + button.textContent = '+'; // Plus sign + } +} + +// Initialize the UI +resetUI(); + +// Make panel headers clickable +document.querySelectorAll('.panel-header').forEach(header => { + header.addEventListener('click', () => { + const panelId = header.parentElement.id; + togglePanel(panelId); + }); +}); \ No newline at end of file diff --git a/example/engine.py b/example/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..8d5f724a95a78a72848435427171ca45a48a6008 --- /dev/null +++ b/example/engine.py @@ -0,0 +1,332 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +MatrixGame Engine + +This module handles the core rendering and model inference for the MatrixGame project. +""" + +import os +import logging +import argparse +import time +import torch +import numpy as np +from PIL import Image +import cv2 +from einops import rearrange +from diffusers.utils import load_image +from diffusers.video_processor import VideoProcessor +from typing import Dict, List, Tuple, Any, Optional, Union + +# MatrixGame specific imports +from matrixgame.sample.pipeline_matrixgame import MatrixGameVideoPipeline +from matrixgame.model_variants import get_dit +from matrixgame.vae_variants import get_vae +from matrixgame.encoder_variants import get_text_enc +from matrixgame.model_variants.matrixgame_dit_src import MGVideoDiffusionTransformerI2V +from matrixgame.sample.flow_matching_scheduler_matrixgame import FlowMatchDiscreteScheduler +from teacache_forward import teacache_forward + +# Import utility functions +from utils import ( + visualize_controls, + frame_to_jpeg, + load_scene_frames, + logger +) + +class MatrixGameEngine: + """ + Core engine for MatrixGame model inference and frame generation. + """ + def __init__(self, args: Optional[argparse.Namespace] = None): + """ + Initialize the MatrixGame engine with configuration parameters. 
+ + Args: + args: Optional parsed command line arguments for model configuration + """ + # Set default parameters if args not provided + self.frame_width = getattr(args, 'frame_width', 640) + self.frame_height = getattr(args, 'frame_height', 360) + self.fps = getattr(args, 'fps', 16) + self.inference_steps = getattr(args, 'inference_steps', 20) + self.guidance_scale = getattr(args, 'guidance_scale', 6.0) + self.num_pre_frames = getattr(args, 'num_pre_frames', 3) + + # Initialize state + self.frame_count = 0 + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.weight_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 + + # Model paths from environment or args + self.vae_path = os.environ.get("VAE_PATH", "./models/matrixgame/vae/") + self.dit_path = os.environ.get("DIT_PATH", "./models/matrixgame/dit/") + self.textenc_path = os.environ.get("TEXTENC_PATH", "./models/matrixgame") + + # Cache scene initial frames + self.scenes = { + 'forest': load_scene_frames('forest', self.frame_width, self.frame_height), + 'desert': load_scene_frames('desert', self.frame_width, self.frame_height), + 'beach': load_scene_frames('beach', self.frame_width, self.frame_height), + 'hills': load_scene_frames('hills', self.frame_width, self.frame_height), + 'river': load_scene_frames('river', self.frame_width, self.frame_height), + 'icy': load_scene_frames('icy', self.frame_width, self.frame_height), + 'mushroom': load_scene_frames('mushroom', self.frame_width, self.frame_height), + 'plain': load_scene_frames('plain', self.frame_width, self.frame_height) + } + + # Cache initial images for model input + self.scene_initial_images = {} + + # Initialize MatrixGame pipeline + self.model_loaded = False + if torch.cuda.is_available(): + try: + self._init_models() + self.model_loaded = True + logger.info("MatrixGame models loaded successfully") + except Exception as e: + logger.error(f"Failed to initialize MatrixGame models: {str(e)}") + logger.info("Falling back to frame cycling mode") + else: + logger.warning("CUDA not available. 
Using frame cycling mode only.") + + def _init_models(self): + """Initialize MatrixGame models (VAE, text encoder, transformer)""" + # Initialize flow matching scheduler + self.scheduler = FlowMatchDiscreteScheduler( + shift=15.0, + reverse=True, + solver="euler" + ) + + # Initialize VAE + try: + self.vae = get_vae("matrixgame", self.vae_path, self.weight_dtype) + self.vae.requires_grad_(False) + self.vae.eval() + self.vae.enable_tiling() + logger.info("VAE model loaded successfully") + except Exception as e: + logger.error(f"Error loading VAE model: {str(e)}") + raise + + # Initialize DIT (Transformer) + try: + dit = MGVideoDiffusionTransformerI2V.from_pretrained(self.dit_path) + dit.requires_grad_(False) + dit.eval() + logger.info("DIT model loaded successfully") + except Exception as e: + logger.error(f"Error loading DIT model: {str(e)}") + raise + + # Initialize text encoder + try: + self.text_enc = get_text_enc('matrixgame', self.textenc_path, weight_dtype=self.weight_dtype, i2v_type='refiner') + logger.info("Text encoder loaded successfully") + except Exception as e: + logger.error(f"Error loading text encoder: {str(e)}") + raise + + # Initialize pipeline + try: + self.pipeline = MatrixGameVideoPipeline( + vae=self.vae.vae, + text_encoder=self.text_enc, + transformer=dit, + scheduler=self.scheduler, + ).to(self.weight_dtype).to(self.device) + logger.info("Pipeline initialized successfully") + except Exception as e: + logger.error(f"Error initializing pipeline: {str(e)}") + raise + + # Configure teacache for the transformer + self.pipeline.transformer.__class__.enable_teacache = True + self.pipeline.transformer.__class__.cnt = 0 + self.pipeline.transformer.__class__.num_steps = self.inference_steps + self.pipeline.transformer.__class__.accumulated_rel_l1_distance = 0 + self.pipeline.transformer.__class__.rel_l1_thresh = 0.075 + self.pipeline.transformer.__class__.previous_modulated_input = None + self.pipeline.transformer.__class__.previous_residual = None + self.pipeline.transformer.__class__.forward = teacache_forward + + # Preprocess initial images for all scenes + for scene_name, frames in self.scenes.items(): + if frames: + # Use first frame as initial image + self.scene_initial_images[scene_name] = self._preprocess_image(frames[0]) + + def _preprocess_image(self, image_array: np.ndarray) -> torch.Tensor: + """ + Preprocess an image for the model. + + Args: + image_array: Input image as numpy array + + Returns: + torch.Tensor: Preprocessed image tensor + """ + # Convert numpy array to PIL Image if needed + if isinstance(image_array, np.ndarray): + image = Image.fromarray(image_array) + else: + image = image_array + + # Preprocess for VAE + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, 'vae') else 8 + video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor) + initial_image = video_processor.preprocess(image, height=self.frame_height, width=self.frame_width) + + # Add past frames for stability (use same frame repeated) + past_frames = initial_image.repeat(self.num_pre_frames, 1, 1, 1) + initial_image = torch.cat([initial_image, past_frames], dim=0) + + return initial_image + + def generate_frame(self, scene_name: str, keyboard_condition: Optional[List] = None, + mouse_condition: Optional[List] = None) -> bytes: + """ + Generate the next frame based on current conditions using MatrixGame model. 
+ + Args: + scene_name: Name of the current scene + keyboard_condition: Keyboard input state + mouse_condition: Mouse input state + + Returns: + bytes: JPEG bytes of the frame + """ + # Check if model is loaded + if not self.model_loaded or not torch.cuda.is_available(): + # Fall back to frame cycling for demo mode or if models failed to load + return self._fallback_frame(scene_name, keyboard_condition, mouse_condition) + else: + # Use MatrixGame model for frame generation + try: + # Get initial image for this scene + initial_image = self.scene_initial_images.get(scene_name) + if initial_image is None: + # Use forest as default if we don't have an initial image for this scene + initial_image = self.scene_initial_images.get('forest') + if initial_image is None: + # If we still don't have an initial image, fall back to frame cycling + logger.error(f"No initial image available for scene {scene_name}") + return self._fallback_frame(scene_name, keyboard_condition, mouse_condition) + + # Prepare input tensors (move to device and format correctly) + if keyboard_condition is None: + keyboard_condition = [[0, 0, 0, 0, 0, 0]] + if mouse_condition is None: + mouse_condition = [[0, 0]] + + # Convert conditions to tensors + keyboard_tensor = torch.tensor(keyboard_condition, dtype=torch.float32) + mouse_tensor = torch.tensor(mouse_condition, dtype=torch.float32) + + # Move to device and convert to correct dtype + keyboard_tensor = keyboard_tensor.to(self.weight_dtype).to(self.device) + mouse_tensor = mouse_tensor.to(self.weight_dtype).to(self.device) + + # Get the first frame from the scene for semantic conditioning + scene_frames = self.scenes.get(scene_name, self.scenes['forest']) + if not scene_frames: + return self._fallback_frame(scene_name, keyboard_condition, mouse_condition) + + semantic_image = Image.fromarray(scene_frames[0]) + + # Get PIL image version of the frame for visualization + for scene_frame in scene_frames: + if isinstance(scene_frame, np.ndarray): + semantic_image = Image.fromarray(scene_frame) + break + + # Generate a single frame with the model + # Use fewer inference steps for interactive frame generation + with torch.no_grad(): + # Generate a short video (we'll just use the first frame) + # We're using a short length (3 frames) for real-time performance + video = self.pipeline( + height=self.frame_height, + width=self.frame_width, + video_length=3, # Generate a very short video for speed + mouse_condition=mouse_tensor, + keyboard_condition=keyboard_tensor, + initial_image=initial_image, + num_inference_steps=self.inference_steps, + guidance_scale=self.guidance_scale, + embedded_guidance_scale=None, + data_type="video", + vae_ver='884-16c-hy', + enable_tiling=True, + generator=torch.Generator(device=self.device).manual_seed(42), + i2v_type='refiner', + semantic_images=semantic_image + ).videos[0] + + # Convert video tensor to numpy array (use first frame) + video_frame = video[0].permute(1, 2, 0).cpu().numpy() + video_frame = (video_frame * 255).astype(np.uint8) + frame = video_frame + + # Increment frame counter + self.frame_count += 1 + + except Exception as e: + logger.error(f"Error generating frame with MatrixGame model: {str(e)}") + # Fall back to cycling demo frames if model generation fails + return self._fallback_frame(scene_name, keyboard_condition, mouse_condition) + + # Add visualization of input controls + frame = visualize_controls( + frame, keyboard_condition, mouse_condition, + self.frame_width, self.frame_height + ) + + # Convert frame to JPEG + return 
frame_to_jpeg(frame, self.frame_height, self.frame_width) + + def _fallback_frame(self, scene_name: str, keyboard_condition: Optional[List] = None, + mouse_condition: Optional[List] = None) -> bytes: + """ + Generate a fallback frame when model generation fails. + + Args: + scene_name: Name of the current scene + keyboard_condition: Keyboard input state + mouse_condition: Mouse input state + + Returns: + bytes: JPEG bytes of the frame + """ + scene_frames = self.scenes.get(scene_name, self.scenes['forest']) + frame_idx = self.frame_count % len(scene_frames) + frame = scene_frames[frame_idx].copy() + self.frame_count += 1 + + # Add fallback mode indicator + cv2.putText(frame, "Fallback mode", + (10, self.frame_height - 20), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) + + # Add visualization of input controls + frame = visualize_controls( + frame, keyboard_condition, mouse_condition, + self.frame_width, self.frame_height + ) + + # Convert frame to JPEG + return frame_to_jpeg(frame, self.frame_height, self.frame_width) + + def get_valid_scenes(self) -> List[str]: + """ + Get a list of valid scene names. + + Returns: + List[str]: List of valid scene names + """ + return list(self.scenes.keys()) \ No newline at end of file diff --git a/example/index.html b/example/index.html new file mode 100644 index 0000000000000000000000000000000000000000..c215658cef7b396ba2435527ea1705a097180f5d --- /dev/null +++ b/example/index.html @@ -0,0 +1,329 @@ + + + + + + MatrixGame Client + + + +
+    <div class="container">
+        <div id="game-panel" class="panel">
+            <div id="mouse-tracking-area">
+                <img id="game-canvas" alt="Game Frame">
+                <div id="mouse-position" class="overlay">Mouse: 0.00, 0.00</div>
+                <div id="fps-counter" class="overlay">FPS: 0</div>
+            </div>
+            <div class="controls">
+                <button id="connect-btn">Connect</button>
+                <select id="scene-select"></select>
+                <button id="start-stream-btn" disabled>Start Stream</button>
+                <button id="stop-stream-btn" disabled>Stop Stream</button>
+            </div>
+        </div>
+
+        <div id="keyboard-panel" class="panel">
+            <div class="panel-header">
+                Keyboard Controls
+                <button class="toggle-button">−</button>
+            </div>
+            <div class="panel-content">
+                <div class="key-row">
+                    <div id="key-w" class="key">W</div>
+                </div>
+                <div class="key-row">
+                    <div id="key-a" class="key">A</div>
+                    <div id="key-s" class="key">S</div>
+                    <div id="key-d" class="key">D</div>
+                </div>
+                <div class="key-row">
+                    <div id="key-space" class="key key-wide">SPACE</div>
+                    <div id="key-shift" class="key key-wide">SHIFT</div>
+                </div>
+                <p class="help-text">
+                    W or ↑ = Forward, S or ↓ = Back, A or ← = Left, D or → = Right<br>
+                    Space = Jump, Shift = Attack<br>
+                    Click on game view to capture mouse (ESC to release)<br>
+                    Mouse = Look around
+                </p>
+            </div>
+        </div>
+
+        <div id="log-panel" class="panel">
+            <div class="panel-header">
+                Connection Log
+                <button class="toggle-button">−</button>
+            </div>
+            <div class="panel-content">
+                <div id="connection-log">
+                    <div class="log-entry">Waiting to connect...</div>
+                </div>
+            </div>
+        </div>
+    </div>
+
+    <script src="assets/client.js"></script>
+ + + + \ No newline at end of file diff --git a/example/requirements.txt b/example/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f903e55a951ef2222ac45cd39ee12716e34c5142 --- /dev/null +++ b/example/requirements.txt @@ -0,0 +1,23 @@ +diffusers==0.32.2 +einops==0.8.1 + +#flash_attn==2.7.4.post1 +flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl + +ftfy==6.3.1 +imageio==2.34.0 +numpy==1.24.4 +opencv_python==4.9.0.80 +opencv_python_headless==4.9.0.80 +packaging==25.0 +peft==0.14.0 +Pillow==11.2.1 +regex==2024.11.6 +safetensors==0.5.3 +torch==2.5.1 +torchvision==0.20.1 +torchaudio==2.5.1 +transformers==4.47.1 +aiohttp==3.9.3 +jinja2==3.1.3 +python-multipart==0.0.6 \ No newline at end of file diff --git a/example/server.py b/example/server.py new file mode 100644 index 0000000000000000000000000000000000000000..67b1703c3761987a855ad09c05f482d339d35495 --- /dev/null +++ b/example/server.py @@ -0,0 +1,649 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +MatrixGame Websocket Gaming Server + +This script implements a websocket server for the MatrixGame project, +allowing real-time streaming of game frames based on player inputs. +""" + +import asyncio +import json +import logging +import os +import pathlib +import time +import uuid +import base64 +import argparse +from typing import Dict, List, Any, Optional +from aiohttp import web, WSMsgType + +# Import the game engine +from engine import MatrixGameEngine +from utils import logger, parse_model_args, setup_gpu_environment + +class GameSession: + """ + Represents a user's gaming session. + Each WebSocket connection gets its own session with separate queues. 
+ """ + def __init__(self, user_id: str, ws: web.WebSocketResponse, game_manager): + self.user_id = user_id + self.ws = ws + self.game_manager = game_manager + + # Create action queue for this user session + self.action_queue = asyncio.Queue() + + # Session creation time + self.created_at = time.time() + self.last_activity = time.time() + + # Game state + self.current_scene = "forest" # Default scene + self.is_streaming = False + self.stream_task = None + + # Current input state + self.keyboard_state = [0, 0, 0, 0, 0, 0] # forward, back, left, right, jump, attack + self.mouse_state = [0, 0] # x, y + + self.background_tasks = [] + + async def start(self): + """Start all the queue processors for this session""" + self.background_tasks = [ + asyncio.create_task(self._process_action_queue()), + ] + logger.info(f"Started game session for user {self.user_id}") + + async def stop(self): + """Stop all background tasks for this session""" + # Stop streaming if active + if self.is_streaming and self.stream_task: + self.is_streaming = False + self.stream_task.cancel() + try: + await self.stream_task + except asyncio.CancelledError: + pass + + # Cancel other background tasks + for task in self.background_tasks: + task.cancel() + + try: + # Wait for tasks to complete cancellation + await asyncio.gather(*self.background_tasks, return_exceptions=True) + except asyncio.CancelledError: + pass + + logger.info(f"Stopped game session for user {self.user_id}") + + async def _process_action_queue(self): + """Process game actions from the queue""" + while True: + data = await self.action_queue.get() + try: + action_type = data.get('action') + + if action_type == 'start_stream': + result = await self._handle_start_stream(data) + elif action_type == 'stop_stream': + result = await self._handle_stop_stream(data) + elif action_type == 'keyboard_input': + result = await self._handle_keyboard_input(data) + elif action_type == 'mouse_input': + result = await self._handle_mouse_input(data) + elif action_type == 'change_scene': + result = await self._handle_scene_change(data) + else: + result = { + 'action': action_type, + 'requestId': data.get('requestId'), + 'success': False, + 'error': f'Unknown action: {action_type}' + } + + # Send response back to the client + await self.ws.send_json(result) + + # Update last activity time + self.last_activity = time.time() + + except Exception as e: + logger.error(f"Error processing action for user {self.user_id}: {str(e)}") + try: + await self.ws.send_json({ + 'action': data.get('action'), + 'requestId': data.get('requestId', 'unknown'), + 'success': False, + 'error': f'Error processing action: {str(e)}' + }) + except Exception as send_error: + logger.error(f"Error sending error response: {send_error}") + finally: + self.action_queue.task_done() + + async def _handle_start_stream(self, data: Dict) -> Dict: + """Handle request to start streaming frames""" + if self.is_streaming: + return { + 'action': 'start_stream', + 'requestId': data.get('requestId'), + 'success': False, + 'error': 'Stream already active' + } + + fps = data.get('fps', 16) + self.is_streaming = True + self.stream_task = asyncio.create_task(self._stream_frames(fps)) + + return { + 'action': 'start_stream', + 'requestId': data.get('requestId'), + 'success': True, + 'message': f'Streaming started at {fps} FPS' + } + + async def _handle_stop_stream(self, data: Dict) -> Dict: + """Handle request to stop streaming frames""" + if not self.is_streaming: + return { + 'action': 'stop_stream', + 'requestId': 
data.get('requestId'), + 'success': False, + 'error': 'No active stream to stop' + } + + self.is_streaming = False + if self.stream_task: + self.stream_task.cancel() + try: + await self.stream_task + except asyncio.CancelledError: + pass + self.stream_task = None + + return { + 'action': 'stop_stream', + 'requestId': data.get('requestId'), + 'success': True, + 'message': 'Streaming stopped' + } + + async def _handle_keyboard_input(self, data: Dict) -> Dict: + """Handle keyboard input from client""" + key = data.get('key', '') + pressed = data.get('pressed', False) + + # Map key to keyboard state index + key_map = { + 'w': 0, 'forward': 0, + 's': 1, 'back': 1, 'backward': 1, + 'a': 2, 'left': 2, + 'd': 3, 'right': 3, + 'space': 4, 'jump': 4, + 'shift': 5, 'attack': 5, 'ctrl': 5 + } + + if key.lower() in key_map: + key_idx = key_map[key.lower()] + self.keyboard_state[key_idx] = 1 if pressed else 0 + + return { + 'action': 'keyboard_input', + 'requestId': data.get('requestId'), + 'success': True, + 'keyboardState': self.keyboard_state + } + + async def _handle_mouse_input(self, data: Dict) -> Dict: + """Handle mouse movement/input from client""" + mouse_x = data.get('x', 0) + mouse_y = data.get('y', 0) + + # Update mouse state, normalize values between -1 and 1 + self.mouse_state = [float(mouse_x), float(mouse_y)] + + return { + 'action': 'mouse_input', + 'requestId': data.get('requestId'), + 'success': True, + 'mouseState': self.mouse_state + } + + async def _handle_scene_change(self, data: Dict) -> Dict: + """Handle scene change requests""" + scene_name = data.get('scene', 'forest') + valid_scenes = self.game_manager.valid_scenes + + if scene_name not in valid_scenes: + return { + 'action': 'change_scene', + 'requestId': data.get('requestId'), + 'success': False, + 'error': f'Invalid scene: {scene_name}. 
Valid scenes are: {", ".join(valid_scenes)}' + } + + self.current_scene = scene_name + + return { + 'action': 'change_scene', + 'requestId': data.get('requestId'), + 'success': True, + 'scene': scene_name + } + + async def _stream_frames(self, fps: int): + """Stream frames to the client at the specified FPS""" + frame_interval = 1.0 / fps # Time between frames in seconds + + try: + while self.is_streaming: + start_time = time.time() + + # Generate frame based on current keyboard and mouse state + keyboard_condition = [self.keyboard_state] + mouse_condition = [self.mouse_state] + + # Use the engine to generate the next frame + frame_bytes = self.game_manager.engine.generate_frame( + self.current_scene, keyboard_condition, mouse_condition + ) + + # Encode as base64 for sending in JSON + frame_base64 = base64.b64encode(frame_bytes).decode('utf-8') + + # Send frame to client + await self.ws.send_json({ + 'action': 'frame', + 'frameData': frame_base64, + 'timestamp': time.time() + }) + + # Calculate sleep time to maintain FPS + elapsed = time.time() - start_time + sleep_time = max(0, frame_interval - elapsed) + await asyncio.sleep(sleep_time) + + except asyncio.CancelledError: + logger.info(f"Frame streaming cancelled for user {self.user_id}") + except Exception as e: + logger.error(f"Error in frame streaming for user {self.user_id}: {str(e)}") + if self.ws.closed: + logger.info(f"WebSocket closed for user {self.user_id}") + return + + # Notify client of error + try: + await self.ws.send_json({ + 'action': 'frame_error', + 'error': f'Streaming error: {str(e)}' + }) + except: + pass + + # Stop streaming + self.is_streaming = False + +class GameManager: + """ + Manages all active gaming sessions and shared resources. + """ + def __init__(self, args: argparse.Namespace): + self.sessions = {} + self.session_lock = asyncio.Lock() + + # Initialize game engine + self.engine = MatrixGameEngine(args) + + # Load valid scenes from engine + self.valid_scenes = self.engine.get_valid_scenes() + + async def create_session(self, user_id: str, ws: web.WebSocketResponse) -> GameSession: + """Create a new game session""" + async with self.session_lock: + # Create a new session for this user + session = GameSession(user_id, ws, self) + await session.start() + self.sessions[user_id] = session + return session + + async def delete_session(self, user_id: str) -> None: + """Delete a game session and clean up resources""" + async with self.session_lock: + if user_id in self.sessions: + session = self.sessions[user_id] + await session.stop() + del self.sessions[user_id] + logger.info(f"Deleted game session for user {user_id}") + + def get_session(self, user_id: str) -> Optional[GameSession]: + """Get a game session if it exists""" + return self.sessions.get(user_id) + + async def close_all_sessions(self) -> None: + """Close all active sessions (used during shutdown)""" + async with self.session_lock: + for user_id, session in list(self.sessions.items()): + await session.stop() + self.sessions.clear() + logger.info("Closed all active game sessions") + + @property + def session_count(self) -> int: + """Get the number of active sessions""" + return len(self.sessions) + + def get_session_stats(self) -> Dict: + """Get statistics about active sessions""" + stats = { + 'total_sessions': len(self.sessions), + 'active_scenes': {}, + 'streaming_sessions': 0 + } + + # Count sessions by scene and streaming status + for session in self.sessions.values(): + scene = session.current_scene + stats['active_scenes'][scene] = 
stats['active_scenes'].get(scene, 0) + 1 + if session.is_streaming: + stats['streaming_sessions'] += 1 + + return stats + +# Create global game manager +game_manager = None + +async def status_handler(request: web.Request) -> web.Response: + """Handler for API status endpoint""" + # Get session statistics + session_stats = game_manager.get_session_stats() + + return web.json_response({ + 'product': 'MatrixGame WebSocket Server', + 'version': '1.0.0', + 'active_sessions': session_stats, + 'available_scenes': game_manager.valid_scenes + }) + +async def root_handler(request: web.Request) -> web.Response: + """Handler for serving the client at the root path""" + client_path = pathlib.Path(__file__).parent / 'client' / 'index.html' + + with open(client_path, 'r') as file: + html_content = file.read() + + return web.Response(text=html_content, content_type='text/html') + +async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + """Handle WebSocket connections with robust error handling""" + logger.info(f"WebSocket connection attempt - PATH: {request.path}, QUERY: {request.query_string}") + + # Log request headers at debug level only (could contain sensitive information) + logger.debug(f"WebSocket request headers: {dict(request.headers)}") + + # Prepare a WebSocket response with appropriate settings + ws = web.WebSocketResponse( + max_msg_size=1024*1024*10, # 10MB max message size + timeout=60.0, + heartbeat=30.0 # Add heartbeat to keep connection alive + ) + + # Check if WebSocket protocol is supported + if not ws.can_prepare(request): + logger.error("Cannot prepare WebSocket: WebSocket protocol not supported") + return web.Response(status=400, text="WebSocket protocol not supported") + + try: + logger.info("Preparing WebSocket connection...") + await ws.prepare(request) + + # Generate a unique user ID for this connection + user_id = str(uuid.uuid4()) + + # Get client IP address + peername = request.transport.get_extra_info('peername') + if peername is not None: + client_ip = peername[0] + else: + client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip() + + # Log connection success + logger.info(f"Client {user_id} connecting from IP: {client_ip} - WebSocket connection established") + + # Mark that the session is established + is_session_created = False + + try: + # Store the user ID in the websocket for easy access + ws.user_id = user_id + + # Create a new session for this user + logger.info(f"Creating game session for user {user_id}") + user_session = await game_manager.create_session(user_id, ws) + is_session_created = True + logger.info(f"Game session created for user {user_id}") + except Exception as session_error: + logger.error(f"Error creating game session: {str(session_error)}", exc_info=True) + if not ws.closed: + await ws.close(code=1011, message=f"Server error: {str(session_error)}".encode()) + if is_session_created: + await game_manager.delete_session(user_id) + return ws + except Exception as e: + logger.error(f"Error establishing WebSocket connection: {str(e)}", exc_info=True) + if not ws.closed and ws.prepared: + await ws.close(code=1011, message=f"Server error: {str(e)}".encode()) + return ws + + # Send initial welcome message + try: + await ws.send_json({ + 'action': 'welcome', + 'userId': user_id, + 'message': 'Welcome to the MatrixGame WebSocket server!', + 'scenes': game_manager.valid_scenes + }) + logger.info(f"Sent welcome message to user {user_id}") + except Exception as welcome_error: + logger.error(f"Error sending welcome 
message: {str(welcome_error)}") + if not ws.closed: + await ws.close(code=1011, message=b"Failed to send welcome message") + await game_manager.delete_session(user_id) + return ws + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + try: + data = json.loads(msg.data) + action = data.get('action') + + logger.debug(f"Received {action} message from user {user_id}") + + if action == 'ping': + # Respond to ping immediately + await ws.send_json({ + 'action': 'pong', + 'requestId': data.get('requestId'), + 'timestamp': time.time() + }) + else: + # Route game actions to the session's action queue + await user_session.action_queue.put(data) + + except json.JSONDecodeError: + logger.error(f"Invalid JSON from user {user_id}: {msg.data}") + if not ws.closed: + await ws.send_json({ + 'error': 'Invalid JSON message', + 'success': False + }) + except Exception as e: + logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}") + if not ws.closed: + await ws.send_json({ + 'action': data.get('action') if 'data' in locals() else 'unknown', + 'success': False, + 'error': f'Error processing message: {str(e)}' + }) + + elif msg.type == WSMsgType.ERROR: + logger.error(f"WebSocket error for user {user_id}: {ws.exception()}") + break + + elif msg.type == WSMsgType.CLOSE: + logger.info(f"WebSocket close received for user {user_id} (code: {msg.data}, message: {msg.extra})") + break + + elif msg.type == WSMsgType.CLOSING: + logger.info(f"WebSocket closing for user {user_id}") + break + + elif msg.type == WSMsgType.CLOSED: + logger.info(f"WebSocket already closed for user {user_id}") + break + + except Exception as ws_error: + logger.error(f"Unexpected WebSocket error for user {user_id}: {str(ws_error)}", exc_info=True) + finally: + # Cleanup session + try: + logger.info(f"Cleaning up session for user {user_id}") + await game_manager.delete_session(user_id) + logger.info(f"Connection closed for user {user_id}") + except Exception as cleanup_error: + logger.error(f"Error during session cleanup for user {user_id}: {str(cleanup_error)}") + + return ws + +async def init_app(args, base_path="") -> web.Application: + """Initialize the web application""" + global game_manager + + # Initialize game manager with command line args + game_manager = GameManager(args) + + app = web.Application( + client_max_size=1024**2*10 # 10MB max size + ) + + # Add cleanup logic + async def cleanup(app): + logger.info("Shutting down server, closing all sessions...") + await game_manager.close_all_sessions() + + app.on_shutdown.append(cleanup) + + # Add routes with CORS headers for WebSockets + # Configure CORS for all routes + @web.middleware + async def cors_middleware(request, handler): + if request.method == 'OPTIONS': + # Handle preflight requests + resp = web.Response() + resp.headers['Access-Control-Allow-Origin'] = '*' + resp.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS' + resp.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With' + return resp + + # Normal request, call the handler + resp = await handler(request) + + # Add CORS headers to the response + resp.headers['Access-Control-Allow-Origin'] = '*' + resp.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS' + resp.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With' + return resp + + app.middlewares.append(cors_middleware) + + # Add a debug endpoint to help diagnose WebSocket issues + async def debug_handler(request): + client_ip = request.remote + headers = 
dict(request.headers) + server_host = request.host + + debug_info = { + "client_ip": client_ip, + "server_host": server_host, + "headers": headers, + "request_path": request.path, + "server_time": time.time(), + "base_path": base_path, + "websocket_route": f"{base_path}/ws", + "all_routes": [route.name for route in app.router.routes() if route.name], + "server_info": { + "active_sessions": game_manager.session_count, + "available_scenes": game_manager.valid_scenes + } + } + + return web.json_response(debug_info) + + # Set up routes with the base_path + # Add multiple WebSocket routes to ensure compatibility + logger.info(f"Setting up WebSocket route at {base_path}/ws") + app.router.add_get(f'{base_path}/ws', websocket_handler, name='ws_handler') + + # Also add WebSocket route at the root for Hugging Face compatibility + if base_path: + logger.info(f"Adding additional WebSocket route at /ws") + app.router.add_get('/ws', websocket_handler, name='ws_root_handler') + + # Add routes for API and debug endpoints + app.router.add_get(f'{base_path}/api/status', status_handler, name='status_handler') + app.router.add_get(f'{base_path}/api/debug', debug_handler, name='debug_handler') + + # Serve the client at both the base path and root path for compatibility + app.router.add_get(f'{base_path}/', root_handler, name='root_handler') + + # Always serve at the root path for Hugging Face Spaces compatibility + if base_path: + app.router.add_get('/', root_handler, name='root_handler_no_base') + + # Set up static file serving for the client assets + app.router.add_static(f'{base_path}/assets', pathlib.Path(__file__).parent / 'client', name='static_handler') + + # Add static file serving at root for compatibility + if base_path: + app.router.add_static('/assets', pathlib.Path(__file__).parent / 'client', name='static_handler_no_base') + + return app + +def parse_args() -> argparse.Namespace: + """Parse server-specific command line arguments""" + parser = argparse.ArgumentParser(description="MatrixGame WebSocket Server") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host IP to bind to") + parser.add_argument("--port", type=int, default=8080, help="Port to listen on") + parser.add_argument("--path", type=str, default="", help="Base path for the server (for proxy setups)") + + # Parse server args first + server_args, remaining_args = parser.parse_known_args() + + # Parse model args and combine + model_args = parse_model_args() + + # Combine all args + combined_args = argparse.Namespace(**vars(server_args), **vars(model_args)) + + return combined_args + +if __name__ == '__main__': + # Configure GPU environment + setup_gpu_environment() + + # Parse command line arguments + args = parse_args() + + # Initialize app + loop = asyncio.get_event_loop() + app = loop.run_until_complete(init_app(args, base_path=args.path)) + + # Start server + logger.info(f"Starting MatrixGame WebSocket Server at {args.host}:{args.port}") + web.run_app(app, host=args.host, port=args.port) \ No newline at end of file diff --git a/example/utils.py b/example/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c35aa64e6d4f4cc1452b3c4789143ce2be6ef7 --- /dev/null +++ b/example/utils.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +MatrixGame Utility Functions + +This module contains helper functions and utilities for the MatrixGame project. 
+""" + +import os +import logging +import argparse +import torch +import numpy as np +import cv2 +from PIL import Image +from typing import Dict, List, Tuple, Any, Optional, Union + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +def setup_gpu_environment(): + """ + Configure the GPU environment and log GPU information. + + Returns: + bool: True if CUDA is available, False otherwise + """ + # Set CUDA memory allocation environment variable for better performance + os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + + # Check if CUDA is available and log information + if torch.cuda.is_available(): + gpu_count = torch.cuda.device_count() + gpu_info = [] + + for i in range(gpu_count): + gpu_name = torch.cuda.get_device_name(i) + gpu_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3) # Convert to GB + gpu_info.append(f"GPU {i}: {gpu_name} ({gpu_memory:.2f} GB)") + + logger.info(f"CUDA is available. Found {gpu_count} GPU(s):") + for info in gpu_info: + logger.info(f" {info}") + return True + else: + logger.warning("CUDA is not available. Running in CPU-only mode.") + return False + +def parse_model_args() -> argparse.Namespace: + """ + Parse command line arguments for model paths and configuration. + + Returns: + argparse.Namespace: Parsed arguments + """ + parser = argparse.ArgumentParser(description="MatrixGame Model Configuration") + + # Model paths + parser.add_argument("--model_root", type=str, default="./models/matrixgame", + help="Root directory for model files") + parser.add_argument("--dit_path", type=str, default=None, + help="Path to DIT model. If not provided, will use MODEL_ROOT/dit/") + parser.add_argument("--vae_path", type=str, default=None, + help="Path to VAE model. If not provided, will use MODEL_ROOT/vae/") + parser.add_argument("--textenc_path", type=str, default=None, + help="Path to text encoder model. 
If not provided, will use MODEL_ROOT")
+
+    # Model settings
+    parser.add_argument("--inference_steps", type=int, default=20,
+                        help="Number of inference steps for frame generation (lower is faster)")
+    parser.add_argument("--guidance_scale", type=float, default=6.0,
+                        help="Guidance scale for generation")
+    parser.add_argument("--frame_width", type=int, default=640,
+                        help="Width of the generated frames")
+    parser.add_argument("--frame_height", type=int, default=360,
+                        help="Height of the generated frames")
+    parser.add_argument("--num_pre_frames", type=int, default=3,
+                        help="Number of pre-frames for conditioning")
+    parser.add_argument("--fps", type=int, default=16,
+                        help="Frames per second for video")
+
+    # Use parse_known_args() so server-level flags (--host, --port, --path)
+    # handled separately by server.py do not trigger an
+    # "unrecognized arguments" error here.
+    args, _ = parser.parse_known_args()
+
+    # Set environment variables for model paths if provided
+    if args.model_root:
+        os.environ.setdefault("MODEL_ROOT", args.model_root)
+    if args.dit_path:
+        os.environ.setdefault("DIT_PATH", args.dit_path)
+    else:
+        os.environ.setdefault("DIT_PATH", os.path.join(os.environ.get("MODEL_ROOT", "./models/matrixgame"), "dit/"))
+    if args.vae_path:
+        os.environ.setdefault("VAE_PATH", args.vae_path)
+    else:
+        os.environ.setdefault("VAE_PATH", os.path.join(os.environ.get("MODEL_ROOT", "./models/matrixgame"), "vae/"))
+    if args.textenc_path:
+        os.environ.setdefault("TEXTENC_PATH", args.textenc_path)
+    else:
+        os.environ.setdefault("TEXTENC_PATH", os.environ.get("MODEL_ROOT", "./models/matrixgame"))
+
+    return args
+
+def visualize_controls(frame: np.ndarray, keyboard_condition: List, mouse_condition: List,
+                       frame_width: int, frame_height: int) -> np.ndarray:
+    """
+    Visualize keyboard and mouse controls on the frame.
+
+    Args:
+        frame: The video frame to visualize on
+        keyboard_condition: Keyboard state as a list
+        mouse_condition: Mouse state as a list
+        frame_width: Width of the frame
+        frame_height: Height of the frame
+
+    Returns:
+        np.ndarray: Frame with visualized controls
+    """
+    # Clone the frame to avoid modifying the original
+    frame = frame.copy()
+
+    # If we have keyboard/mouse conditions, visualize them on the frame
+    if keyboard_condition:
+        # Visualize keyboard inputs
+        keys = ["W", "S", "A", "D", "JUMP", "ATTACK"]
+        # zip() guards against a keyboard state longer than the key list
+        for i, (key, key_pressed) in enumerate(zip(keys, keyboard_condition[0])):
+            color = (0, 255, 0) if key_pressed else (100, 100, 100)
+            cv2.putText(frame, key, (20 + i*100, 30),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
+
+    if mouse_condition:
+        # Visualize mouse movement
+        mouse_x, mouse_y = mouse_condition[0]
+        # Scale mouse values for visualization
+        offset_x = int(mouse_x * 100)
+        offset_y = int(mouse_y * 100)
+        center_x, center_y = frame_width // 2, frame_height // 2
+        cv2.circle(frame, (center_x + offset_x, center_y - offset_y), 10, (255, 0, 0), -1)
+        cv2.putText(frame, f"Mouse: {mouse_x:.2f}, {mouse_y:.2f}",
+                    (frame_width - 250, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
+
+    return frame
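For reference, a minimal sketch of how the conditioning structures above are shaped. The six-key order mirrors the `keys` list in `visualize_controls`; the blank test frame and the concrete values are illustrative assumptions, not project code:

```python
import numpy as np

# Illustrative only: one player pressing W and ATTACK while looking
# slightly right and up. Outer list = batch, inner list = key/axis values.
keyboard_condition = [[1, 0, 0, 0, 0, 1]]   # [W, S, A, D, JUMP, ATTACK]
mouse_condition = [[0.25, 0.10]]            # small float (x, y) look offsets

frame = np.zeros((360, 640, 3), dtype=np.uint8)  # blank 640x360 test frame
overlay = visualize_controls(frame, keyboard_condition, mouse_condition,
                             frame_width=640, frame_height=360)
print(overlay.shape)  # (360, 640, 3), with the control overlay drawn in
```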
+
+def frame_to_jpeg(frame: np.ndarray, frame_height: int, frame_width: int) -> bytes:
+    """
+    Convert a frame to JPEG bytes.
+
+    Args:
+        frame: The video frame to convert
+        frame_height: Height of the frame for fallback
+        frame_width: Width of the frame for fallback
+
+    Returns:
+        bytes: JPEG bytes of the frame
+    """
+    success, buffer = cv2.imencode('.jpg', frame)
+    if not success:
+        logger.error("Failed to encode frame as JPEG")
+        # Return a blank frame
+        blank = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 100
+        success, buffer = cv2.imencode('.jpg', blank)
+
+    return buffer.tobytes()
+
+def load_scene_frames(scene_name: str, frame_width: int, frame_height: int) -> List[np.ndarray]:
+    """
+    Load initial frames for a scene from the asset directory.
+
+    Args:
+        scene_name: Name of the scene
+        frame_width: Width to resize frames to
+        frame_height: Height to resize frames to
+
+    Returns:
+        List[np.ndarray]: List of frames as numpy arrays
+    """
+    frames = []
+    scene_dir = f"./GameWorldScore/asset/init_image/{scene_name}"
+
+    if os.path.exists(scene_dir):
+        image_files = sorted([f for f in os.listdir(scene_dir) if f.endswith('.png') or f.endswith('.jpg')])
+        for img_file in image_files:
+            try:
+                img_path = os.path.join(scene_dir, img_file)
+                img = Image.open(img_path).convert("RGB")
+                img = img.resize((frame_width, frame_height))
+                frames.append(np.array(img))
+            except Exception as e:
+                logger.error(f"Error loading image {img_file}: {str(e)}")
+
+    # If no frames were loaded, create a default colored frame with text
+    if not frames:
+        # numpy array shapes are (height, width, channels)
+        frame = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 100
+        # Add scene name as text
+        cv2.putText(frame, f"Scene: {scene_name}", (50, 180),
+                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
+        frames.append(frame)
+
+    return frames
\ No newline at end of file
diff --git a/game/spawn/1/act.npy b/game/spawn/1/act.npy
new file mode 100644
index 0000000000000000000000000000000000000000..2924767c1a504eb1e15ce0ac28d39624e61fe7d0
--- /dev/null
+++ b/game/spawn/1/act.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7caabf75f45d4c8bae5c0b66dc2b5a3cbf3ab7dbf89521d6ba539c4f30048d75
+size 10688
diff --git a/game/spawn/1/full_res.npy b/game/spawn/1/full_res.npy
new file mode 100644
index 0000000000000000000000000000000000000000..4e6c6c2ff258a8d56c070de131ed13d63968d5d7
--- /dev/null
+++ b/game/spawn/1/full_res.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c186e607a8cc4922e17ee66c1f37dc4858adfef220b6bf48fbcda9bf75ffde34
+size 22260128
diff --git a/game/spawn/1/low_res.npy b/game/spawn/1/low_res.npy
new file mode 100644
index 0000000000000000000000000000000000000000..e7b92ab83886e42ab0fc6d3c4146d34a846093a0
--- /dev/null
+++ b/game/spawn/1/low_res.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08e31f7ef447a1dffda3b51b4102ae4301f5804145b63c71acc5c278a294b1ee
+size 368768
diff --git a/game/spawn/1/next_act.npy b/game/spawn/1/next_act.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7c6afd9bd0075b320e33f95e49cad10748b53dbe
--- /dev/null
+++ b/game/spawn/1/next_act.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0f2d1c96337459ddd84b9c1a8dbad9eb7284cb813a2f0ebd9cb4d757dd294e1
+size 105728
diff --git a/game/spawn/2/act.npy b/game/spawn/2/act.npy
new file mode 100644
index 0000000000000000000000000000000000000000..4745e120d72bd68c4de9b05c7b31ae9b6ba17894
--- /dev/null
+++ b/game/spawn/2/act.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:619cdda3de7f55e48a753b64542354387294c2688a6e014af85b094996b8a486
+size 10688
diff
--git a/game/spawn/2/full_res.npy b/game/spawn/2/full_res.npy new file mode 100644 index 0000000000000000000000000000000000000000..7d5357363ae4968b166615ce8bde52240a94fcde --- /dev/null +++ b/game/spawn/2/full_res.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:911389a5957acd8d7b96fccb1917ee4a4b0c74f0b9420f61580b571965dd99ff +size 22260128 diff --git a/game/spawn/2/low_res.npy b/game/spawn/2/low_res.npy new file mode 100644 index 0000000000000000000000000000000000000000..dc811a357c4143af270640a78f2a1a0be290bf2c --- /dev/null +++ b/game/spawn/2/low_res.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8351be59118b4c2237800a9119937d472564e63c70a8d1911bc2d52ac3a95a2 +size 368768 diff --git a/game/spawn/2/next_act.npy b/game/spawn/2/next_act.npy new file mode 100644 index 0000000000000000000000000000000000000000..d1a89ff69cc6d2aabd893a9c67cd58045b4fe61f --- /dev/null +++ b/game/spawn/2/next_act.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba1b85fd4548f805352306a04e24368718328bfd513cb95305d0f9284fe9719f +size 105728 diff --git a/game/spawn/3/act.npy b/game/spawn/3/act.npy new file mode 100644 index 0000000000000000000000000000000000000000..795e24fb86302d7e4f38ef66d3a3ae87bd513fc6 --- /dev/null +++ b/game/spawn/3/act.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60ff3ff40a48105e33feae08979dd4d7d7570984e10c950aae9308c08841400a +size 10688 diff --git a/game/spawn/3/full_res.npy b/game/spawn/3/full_res.npy new file mode 100644 index 0000000000000000000000000000000000000000..71f10372b77761a15339fb5f1758a10e59dbf86f --- /dev/null +++ b/game/spawn/3/full_res.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1f8a9a266e05f8d8a741dadb697bf3ed89c1a166dab443c1d81091c3cf1824b +size 22260128 diff --git a/game/spawn/3/low_res.npy b/game/spawn/3/low_res.npy new file mode 100644 index 0000000000000000000000000000000000000000..de23d6a00891c2f79014c4981f6e2c1b528d1d19 --- /dev/null +++ b/game/spawn/3/low_res.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0dd822892eb5201e23d5298da69523a4d43d740e37d7647d30a288a5b440991e +size 368768 diff --git a/game/spawn/3/next_act.npy b/game/spawn/3/next_act.npy new file mode 100644 index 0000000000000000000000000000000000000000..17f461d04b5278bad039cee29e2c1966a20b5528 --- /dev/null +++ b/game/spawn/3/next_act.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40f2a5d1381a93fc46e58f7ac0bf5df9f507c5c3b6489e85a0896deba8da9dca +size 105728 diff --git a/index.html b/index.html new file mode 100644 index 0000000000000000000000000000000000000000..90e4922b2881ce5d0f18ee2e4adad4ec10f623d3 --- /dev/null +++ b/index.html @@ -0,0 +1,928 @@ + + + + + + AI Game Multiverse + + + +
+

AI Game Multiverse

+

Play procedurally generated games using AI

+
+ +
+
+
+ Game Frame +
Mouse: 0.00, 0.00
+
FPS: 0
+
+ +
+ + + + +
+
+ +
+ +
+
+
Keyboard Controls
+ +
+
+
+
+
W
+
+
+
A
+
S
+
D
+
+
+
SPACE
+
+
+
SHIFT
+
+
+

+ W or ↑ = Forward, S or ↓ = Back, A or ← = Left, D or → = Right
+ Space = Jump, Shift = Attack
+ Click on game view to capture mouse (ESC to release)
+ Mouse = Look around +
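Raw pointer-lock pixel deltas have to be squashed into the small float pair shown in the mouse readout above before being sent to the server. The exact client-side mapping is not shown here, so the following is a hypothetical Python sketch; the sensitivity constant and the clamping range are assumptions:

```python
def normalize_mouse(dx: float, dy: float, sensitivity: float = 0.005) -> list:
    """Map raw pointer deltas to the small float pair used by mouse_condition.
    The sensitivity value is a guess, not taken from the project."""
    clamp = lambda v: max(-1.0, min(1.0, v))
    return [clamp(dx * sensitivity), clamp(dy * sensitivity)]

print(normalize_mouse(40, -12))  # e.g. [0.2, -0.06]
```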

+
+
+ + +
+
+
Connection Log
+ +
+
+
+
Welcome to AI Game Multiverse. Click Connect to begin.
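As a companion to the browser flow above, a minimal Python client can exercise the same handshake. This is a sketch, not project code: the URL is an assumption, and only the welcome and ping/pong actions defined in server.py are used.

```python
import asyncio

import aiohttp

async def handshake(url: str = "ws://localhost:8080/ws") -> None:  # URL is an assumption
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(url) as ws:
            welcome = await ws.receive_json()  # server greets with action 'welcome'
            print("user:", welcome["userId"], "scenes:", welcome["scenes"])
            await ws.send_json({"action": "ping", "requestId": "demo-1"})
            pong = await ws.receive_json()     # server answers with action 'pong'
            print("round trip ok:", pong.get("action") == "pong")

asyncio.run(handshake())
```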
+
+
+
+
+
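Once the server is streaming, each 'frame' message carries a base64-encoded JPEG (see _stream_frames in server.py). A sketch of handling those messages on the client side, using Pillow, which the project already depends on:

```python
import base64
import io

from PIL import Image

def handle_message(message: dict) -> None:
    """Dispatch one decoded JSON message from the /ws stream (sketch)."""
    action = message.get("action")
    if action == "frame":
        jpeg = base64.b64decode(message["frameData"])  # base64 JPEG payload
        image = Image.open(io.BytesIO(jpeg))
        image.load()  # force decode; hand the image to a display loop here
    elif action == "frame_error":
        raise RuntimeError(message.get("error", "unknown streaming error"))
```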
+ + + + \ No newline at end of file diff --git a/reference_example/Dockerfile b/reference_example/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..427f95c1602875b9f737a838afd29ae250e84831 --- /dev/null +++ b/reference_example/Dockerfile @@ -0,0 +1,52 @@ +FROM nvidia/cuda:12.4.0-devel-ubuntu22.04 + +ARG DEBIAN_FRONTEND=noninteractive + +ENV PYTHONUNBUFFERED=1 + +RUN apt-get update && apt-get install --no-install-recommends -y \ + build-essential \ + python3.11 \ + python3-pip \ + python3-dev \ + git \ + curl \ + ffmpeg \ + libglib2.0-0 \ + libsm6 \ + libxrender1 \ + libxext6 \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +WORKDIR /code + +COPY ./requirements.txt /code/requirements.txt + +# Set up a new user named "user" with user ID 1000 +RUN useradd -m -u 1000 user +# Switch to the "user" user +USER user +# Set home to the user's home directory +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +# Set home to the user's home directory +ENV PYTHONPATH=$HOME/app \ + PYTHONUNBUFFERED=1 \ + DATA_ROOT=/tmp/data + +RUN echo "Installing requirements.txt" +RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt + +# yeah.. this is manual for now +#RUN flutter build web + +WORKDIR $HOME/app + +COPY --chown=user . $HOME/app + +EXPOSE 8080 + +ENV PORT 8080 + +CMD python3 api.py diff --git a/reference_example/api.py b/reference_example/api.py new file mode 100644 index 0000000000000000000000000000000000000000..5be3e44d40ab1e88c37f36332d9349e9ddde884b --- /dev/null +++ b/reference_example/api.py @@ -0,0 +1,297 @@ +import asyncio +import json +import logging +import os +import pathlib +import time +import uuid +from aiohttp import web, WSMsgType +from typing import Dict, Any + +from api_core import VideoGenerationAPI +from api_session import SessionManager +from api_metrics import MetricsTracker +from api_config import * + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +# Create global session and metrics managers +session_manager = SessionManager() +metrics_tracker = MetricsTracker() + +# Dictionary to track connected anonymous clients by IP address +anon_connections = {} +anon_connection_lock = asyncio.Lock() + +async def status_handler(request: web.Request) -> web.Response: + """Handler for API status endpoint""" + api = session_manager.shared_api + + # Get current busy status of all endpoints + endpoint_statuses = [] + for ep in api.endpoint_manager.endpoints: + endpoint_statuses.append({ + 'id': ep.id, + 'url': ep.url, + 'busy': ep.busy, + 'last_used': ep.last_used, + 'error_count': ep.error_count, + 'error_until': ep.error_until + }) + + # Get session statistics + session_stats = session_manager.get_session_stats() + + # Get metrics + api_metrics = metrics_tracker.get_metrics() + + return web.json_response({ + 'product': PRODUCT_NAME, + 'version': PRODUCT_VERSION, + 'maintenance_mode': MAINTENANCE_MODE, + 'available_endpoints': len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS), + 'endpoint_status': endpoint_statuses, + 'active_endpoints': sum(1 for ep in endpoint_statuses if not ep['busy'] and ('error_until' not in ep or ep['error_until'] < time.time())), + 'active_sessions': session_stats, + 'metrics': api_metrics + }) + +async def metrics_handler(request: web.Request) -> web.Response: + """Handler for detailed metrics endpoint (protected)""" + # Check for API key in header or query param + auth_header = 
request.headers.get('Authorization', '') + api_key = None + + if auth_header.startswith('Bearer '): + api_key = auth_header[7:] + else: + api_key = request.query.get('key') + + # Validate API key (using SECRET_TOKEN as the API key) + if not api_key or api_key != SECRET_TOKEN: + return web.json_response({ + 'error': 'Unauthorized' + }, status=401) + + # Get detailed metrics + detailed_metrics = metrics_tracker.get_detailed_metrics() + + return web.json_response(detailed_metrics) + +async def websocket_handler(request: web.Request) -> web.WebSocketResponse: + # Check if maintenance mode is enabled + if MAINTENANCE_MODE: + # Return an error response indicating maintenance mode + return web.json_response({ + 'error': 'Server is in maintenance mode', + 'maintenance': True + }, status=503) # 503 Service Unavailable + + ws = web.WebSocketResponse( + max_msg_size=1024*1024*20, # 20MB max message size + timeout=30.0 # we want to keep things tight and short + ) + + await ws.prepare(request) + + # Get the Hugging Face token from query parameters + hf_token = request.query.get('hf_token', '') + + # Generate a unique user ID for this connection + user_id = str(uuid.uuid4()) + + # Validate the token and determine the user role + user_role = await session_manager.shared_api.validate_user_token(hf_token) + logger.info(f"User {user_id} connected with role: {user_role}") + + # Get client IP address + peername = request.transport.get_extra_info('peername') + if peername is not None: + client_ip = peername[0] + else: + client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip() + + logger.info(f"Client {user_id} connecting from IP: {client_ip} with role: {user_role}") + + # Check for anonymous user connection limits + if user_role == 'anon': + async with anon_connection_lock: + # Track this connection + anon_connections[client_ip] = anon_connections.get(client_ip, 0) + 1 + # Store the IP so we can clean up later + ws.client_ip = client_ip + + # Log multiple connections from same IP but don't restrict them + if anon_connections[client_ip] > 1: + logger.info(f"Multiple anonymous connections from IP {client_ip}: {anon_connections[client_ip]} connections") + + # Store the user role in the websocket for easy access + ws.user_role = user_role + ws.user_id = user_id + + # Register with metrics + metrics_tracker.register_session(user_id, client_ip) + + # Create a new session for this user + user_session = await session_manager.create_session(user_id, user_role, ws) + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + try: + data = json.loads(msg.data) + action = data.get('action') + + # Check for rate limiting + request_type = 'other' + if action in ['join_chat', 'leave_chat', 'chat_message']: + request_type = 'chat' + elif action in ['generate_video']: + request_type = 'video' + elif action == 'search': + request_type = 'search' + elif action == 'simulate': + request_type = 'simulation' + + # Record the request for metrics + await metrics_tracker.record_request(user_id, client_ip, request_type, user_role) + + # Check rate limits (except for admins) + if user_role != 'admin' and await metrics_tracker.is_rate_limited(user_id, request_type, user_role): + await ws.send_json({ + 'action': action, + 'requestId': data.get('requestId'), + 'success': False, + 'error': f'Rate limit exceeded for {request_type} requests. Please try again later.' 
+ }) + continue + + # Route requests to appropriate queues + if action in ['join_chat', 'leave_chat', 'chat_message']: + await user_session.chat_queue.put(data) + elif action in ['generate_video']: + await user_session.video_queue.put(data) + elif action == 'search': + await user_session.search_queue.put(data) + elif action == 'simulate': + await user_session.simulation_queue.put(data) + else: + await user_session.process_generic_request(data) + + except Exception as e: + logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}") + await ws.send_json({ + 'action': data.get('action') if 'data' in locals() else 'unknown', + 'success': False, + 'error': f'Error processing message: {str(e)}' + }) + + elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE): + break + + finally: + # Cleanup session + await session_manager.delete_session(user_id) + + # Cleanup anonymous connection tracking + if getattr(ws, 'user_role', None) == 'anon' and hasattr(ws, 'client_ip'): + client_ip = ws.client_ip + async with anon_connection_lock: + if client_ip in anon_connections: + anon_connections[client_ip] = max(0, anon_connections[client_ip] - 1) + if anon_connections[client_ip] == 0: + del anon_connections[client_ip] + logger.info(f"Anonymous connection from {client_ip} closed. Remaining: {anon_connections.get(client_ip, 0)}") + + # Unregister from metrics + metrics_tracker.unregister_session(user_id, client_ip) + logger.info(f"Connection closed for user {user_id}") + + return ws + +async def init_app() -> web.Application: + app = web.Application( + client_max_size=1024**2*20 # 20MB max size + ) + + # Add cleanup logic + async def cleanup(app): + logger.info("Shutting down server, closing all sessions...") + await session_manager.close_all_sessions() + + app.on_shutdown.append(cleanup) + + # Add routes + app.router.add_get('/ws', websocket_handler) + app.router.add_get('/api/status', status_handler) + app.router.add_get('/api/metrics', metrics_handler) + + # Set up static file serving + # Define the path to the public directory + public_path = pathlib.Path(__file__).parent / 'build' / 'web' + if not public_path.exists(): + public_path.mkdir(parents=True, exist_ok=True) + + # Set up static file serving with proper security considerations + async def static_file_handler(request): + # Get the path from the request (removing leading /) + path_parts = request.path.lstrip('/').split('/') + + # Convert to safe path to prevent path traversal attacks + safe_path = public_path.joinpath(*path_parts) + + # Make sure the path is within the public directory (prevent directory traversal) + try: + safe_path = safe_path.resolve() + if not str(safe_path).startswith(str(public_path.resolve())): + return web.HTTPForbidden(text="Access denied") + except (ValueError, FileNotFoundError): + return web.HTTPNotFound() + + # If path is a directory, look for index.html + if safe_path.is_dir(): + safe_path = safe_path / 'index.html' + + # Check if the file exists + if not safe_path.exists() or not safe_path.is_file(): + # If not found, serve index.html (for SPA routing) + safe_path = public_path / 'index.html' + if not safe_path.exists(): + return web.HTTPNotFound() + + # Determine content type based on file extension + content_type = 'text/plain' + ext = safe_path.suffix.lower() + if ext == '.html': + content_type = 'text/html' + elif ext == '.js': + content_type = 'application/javascript' + elif ext == '.css': + content_type = 'text/css' + elif ext in ('.jpg', '.jpeg'): + content_type = 'image/jpeg' + elif ext == 
'.png': + content_type = 'image/png' + elif ext == '.gif': + content_type = 'image/gif' + elif ext == '.svg': + content_type = 'image/svg+xml' + elif ext == '.json': + content_type = 'application/json' + + # Return the file with appropriate headers + return web.FileResponse(safe_path, headers={'Content-Type': content_type}) + + # Add catch-all route for static files (lower priority than API routes) + app.router.add_get('/{path:.*}', static_file_handler) + + return app + +if __name__ == '__main__': + app = asyncio.run(init_app()) + web.run_app(app, host='0.0.0.0', port=8080) \ No newline at end of file diff --git a/reference_example/api_config.py b/reference_example/api_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b366493d35c59e610b47f8db2b84238e99b1771e --- /dev/null +++ b/reference_example/api_config.py @@ -0,0 +1,184 @@ +import os + +PRODUCT_NAME = os.environ.get('PRODUCT_NAME', 'TikSlop') +PRODUCT_VERSION = "2.0.0" + +# you should use Mistral 7b instruct for good performance and accuracy balance +TEXT_MODEL = os.environ.get('HF_TEXT_MODEL', '') + +# Environment variable to control maintenance mode +MAINTENANCE_MODE = os.environ.get('MAINTENANCE_MODE', 'false').lower() in ('true', 'yes', '1', 't') + +# Environment variable to control how many nodes to use +MAX_NODES = int(os.environ.get('MAX_NODES', '8')) + +ADMIN_ACCOUNTS = [ + "jbilcke-hf" +] + +RAW_VIDEO_ROUND_ROBIN_ENDPOINT_URLS = [ + os.environ.get('VIDEO_ROUND_ROBIN_SERVER_1', ''), + os.environ.get('VIDEO_ROUND_ROBIN_SERVER_2', ''), + os.environ.get('VIDEO_ROUND_ROBIN_SERVER_3', ''), + os.environ.get('VIDEO_ROUND_ROBIN_SERVER_4', ''), + os.environ.get('VIDEO_ROUND_ROBIN_SERVER_5', ''), + os.environ.get('VIDEO_ROUND_ROBIN_SERVER_6', ''), + os.environ.get('VIDEO_ROUND_ROBIN_SERVER_7', ''), + os.environ.get('VIDEO_ROUND_ROBIN_SERVER_8', ''), +] + +# Filter out empty strings from the endpoint list +filtered_urls = [url for url in RAW_VIDEO_ROUND_ROBIN_ENDPOINT_URLS if url] + +# Limit the number of URLs based on MAX_NODES environment variable +VIDEO_ROUND_ROBIN_ENDPOINT_URLS = filtered_urls[:MAX_NODES] + +HF_TOKEN = os.environ.get('HF_TOKEN') + +# use the same secret token as you used to secure your BASE_SPACE_NAME spaces +SECRET_TOKEN = os.environ.get('SECRET_TOKEN') + +# altenative words we could use: "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres" +NEGATIVE_PROMPT = "low quality, worst quality, deformed, distorted, disfigured, blurry, text, watermark" + +POSITIVE_PROMPT_SUFFIX = "high quality, cinematic, 4K, intricate details" + +GUIDANCE_SCALE = 1.0 + +THUMBNAIL_FRAMES = 65 + +# anonymous users are people browing TikSlop without being connected +# this category suffers from regular abuse so we need to enforce strict limitations +CONFIG_FOR_ANONYMOUS_USERS = { + + # anons can only watch 2 minutes per video + "max_rendering_time_per_client_per_video_in_sec": 2 * 60, + + "min_num_inference_steps": 2, + "default_num_inference_steps": 4, + "max_num_inference_steps": 4, + + "min_num_frames": 9, # 8 + 1 + "default_max_num_frames": 65, # 8*8 + 1 + "max_num_frames": 65, # 8*8 + 1 + + "min_clip_duration_seconds": 1, + "default_clip_duration_seconds": 2, + "max_clip_duration_seconds": 2, + + "min_clip_playback_speed": 0.7, + "default_clip_playback_speed": 0.7, + "max_clip_playback_speed": 0.7, + + "min_clip_framerate": 8, + "default_clip_framerate": 
16, + "max_clip_framerate": 16, + + "min_clip_width": 544, + "default_clip_width": 640, + "max_clip_width": 640, + + "min_clip_height": 320, + "default_clip_height": 352, + "max_clip_height": 352, +} + +# Hugging Face users enjoy a more normal and calibrated experience +CONFIG_FOR_STANDARD_HF_USERS = { + "max_rendering_time_per_client_per_video_in_sec": 15 * 60, + + "min_num_inference_steps": 2, + "default_num_inference_steps": 4, + "max_num_inference_steps": 4, + + "min_num_frames": 9, # 8 + 1 + "default_num_frames": 81, # 8*10 + 1 + "max_num_frames": 81, + + "min_clip_duration_seconds": 1, + "default_clip_duration_seconds": 3, + "max_clip_duration_seconds": 3, + + "min_clip_playback_speed": 0.7, + "default_clip_playback_speed": 0.7, + "max_clip_playback_speed": 0.7, + + "min_clip_framerate": 8, + "default_clip_framerate": 25, + "max_clip_framerate": 25, + + "min_clip_width": 544, + "default_clip_width": 1152, # 928, # 1216, # 768, # 640, + "max_clip_width": 1152, # 928, # 1216, # 768, # 640, + + "min_clip_height": 320, + "default_clip_height": 640, # 512, # 448, # 416, + "max_clip_height": 640, # 512, # 448, # 416, +} + +# Hugging Face users with a Pro may enjoy an improved experience +CONFIG_FOR_PRO_HF_USERS = { + "max_rendering_time_per_client_per_video_in_sec": 20 * 60, + + "min_num_inference_steps": 2, + "default_num_inference_steps": 4, + "max_num_inference_steps": 4, + + "min_num_frames": 9, # 8 + 1 + "default_num_frames": 81, # 8*10 + 1 + "max_num_frames": 81, + + "min_clip_duration_seconds": 1, + "default_clip_duration_seconds": 3, + "max_clip_duration_seconds": 3, + + "min_clip_playback_speed": 0.7, + "default_clip_playback_speed": 0.7, + "max_clip_playback_speed": 0.7, + + "min_clip_framerate": 8, + "default_clip_framerate": 25, + "max_clip_framerate": 25, + + "min_clip_width": 544, + "default_clip_width": 1152, # 928, # 1216, # 768, # 640, + "max_clip_width": 1152, # 928, # 1216, # 768, # 640, + + "min_clip_height": 320, + "default_clip_height": 640, # 512, # 448, # 416, + "max_clip_height": 640, # 512, # 448, # 416, +} + +CONFIG_FOR_ADMIN_HF_USERS = { + "max_rendering_time_per_client_per_video_in_sec": 60 * 60, + + "min_num_inference_steps": 2, + "default_num_inference_steps": 4, + "max_num_inference_steps": 4, + + "min_num_frames": 9, # 8 + 1 + "default_num_frames": 81, # (8 * 10) + 1 + "max_num_frames": 129, # (8 * 16) + 1 + + "min_clip_duration_seconds": 1, + "default_clip_duration_seconds": 2, + "max_clip_duration_seconds": 4, + + "min_clip_playback_speed": 0.7, + "default_clip_playback_speed": 0.7, + "max_clip_playback_speed": 1.0, + + "min_clip_framerate": 8, + "default_clip_framerate": 30, + "max_clip_framerate": 60, + + "min_clip_width": 544, + "default_clip_width": 1152, # 928, # 1216, # 768, # 640, + "max_clip_width": 1152, # 928, # 1216, # 768, # 640, + + "min_clip_height": 320, + "default_clip_height": 640, # 512, # 448, # 416, + "max_clip_height": 640, # 512, # 448, # 416, +} + +CONFIG_FOR_ADMIN_HF_USERS = CONFIG_FOR_PRO_HF_USERS \ No newline at end of file diff --git a/reference_example/api_core.py b/reference_example/api_core.py new file mode 100644 index 0000000000000000000000000000000000000000..8066d154d3d041cc08b3f231ff9b0378b9b5424a --- /dev/null +++ b/reference_example/api_core.py @@ -0,0 +1,1068 @@ +import logging +import os +import io +import re +import base64 +import uuid +from typing import Dict, Any, Optional, List, Literal +from dataclasses import dataclass +from asyncio import Lock, Queue +import asyncio +import time +import datetime +from 
contextlib import asynccontextmanager +from collections import defaultdict +from aiohttp import web, ClientSession +from huggingface_hub import InferenceClient, HfApi +from gradio_client import Client +import random +import yaml +import json + +from api_config import * + +# User role type +UserRole = Literal['anon', 'normal', 'pro', 'admin'] + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +def generate_seed(): + """Generate a random positive 32-bit integer seed.""" + return random.randint(0, 2**32 - 1) + +def sanitize_yaml_response(response_text: str) -> str: + """ + Sanitize and format AI response into valid YAML. + Returns properly formatted YAML string. + """ + + response_text = response_text.split("```")[0] + + # Remove any markdown code block indicators and YAML document markers + clean_text = re.sub(r'```yaml|```|---|\.\.\.$', '', response_text.strip()) + + # Split into lines and process each line + lines = clean_text.split('\n') + sanitized_lines = [] + current_field = None + + for line in lines: + stripped = line.strip() + if not stripped: + continue + + # Handle field starts + if stripped.startswith('title:') or stripped.startswith('description:'): + # Ensure proper YAML format with space after colon and proper quoting + field_name = stripped.split(':', 1)[0] + field_value = stripped.split(':', 1)[1].strip().strip('"\'') + + # Quote the value if it contains special characters + if any(c in field_value for c in ':[]{},&*#?|-<>=!%@`'): + field_value = f'"{field_value}"' + + sanitized_lines.append(f"{field_name}: {field_value}") + current_field = field_name + + elif stripped.startswith('tags:'): + sanitized_lines.append('tags:') + current_field = 'tags' + + elif stripped.startswith('-') and current_field == 'tags': + # Process tag values + tag = stripped[1:].strip().strip('"\'') + if tag: + # Clean and format tag + tag = re.sub(r'[^\x00-\x7F]+', '', tag) # Remove non-ASCII + tag = re.sub(r'[^a-zA-Z0-9\s-]', '', tag) # Keep only alphanumeric and hyphen + tag = tag.strip().lower().replace(' ', '-') + if tag: + sanitized_lines.append(f" - {tag}") + + elif current_field in ['title', 'description']: + # Handle multi-line title/description continuation + value = stripped.strip('"\'') + if value: + # Append to previous line + prev = sanitized_lines[-1] + sanitized_lines[-1] = f"{prev} {value}" + + # Ensure the YAML has all required fields + required_fields = {'title', 'description', 'tags'} + found_fields = {line.split(':')[0].strip() for line in sanitized_lines if ':' in line} + + for field in required_fields - found_fields: + if field == 'tags': + sanitized_lines.extend(['tags:', ' - default']) + else: + sanitized_lines.append(f'{field}: "No {field} provided"') + + return '\n'.join(sanitized_lines) + +@dataclass +class Endpoint: + id: int + url: str + busy: bool = False + last_used: float = 0 + error_count: int = 0 + error_until: float = 0 # Timestamp until which this endpoint is considered in error state + +class EndpointManager: + def __init__(self): + self.endpoints: List[Endpoint] = [] + self.lock = Lock() + self.initialize_endpoints() + self.last_used_index = -1 # Track the last used endpoint for round-robin + + def initialize_endpoints(self): + """Initialize the list of endpoints""" + for i, url in enumerate(VIDEO_ROUND_ROBIN_ENDPOINT_URLS): + endpoint = Endpoint(id=i + 1, url=url) + self.endpoints.append(endpoint) + + def _get_next_free_endpoint(self): + 
"""Get the next available non-busy endpoint, or oldest endpoint if all are busy""" + current_time = time.time() + + # First priority: Get any non-busy and non-error endpoint + free_endpoints = [ + ep for ep in self.endpoints + if not ep.busy and current_time > ep.error_until + ] + + if free_endpoints: + # Return the least recently used free endpoint + return min(free_endpoints, key=lambda ep: ep.last_used) + + # Second priority: If all busy/error, use round-robin but skip error endpoints + tried_count = 0 + next_index = self.last_used_index + + while tried_count < len(self.endpoints): + next_index = (next_index + 1) % len(self.endpoints) + tried_count += 1 + + # If endpoint is not in error state, use it + if current_time > self.endpoints[next_index].error_until: + self.last_used_index = next_index + return self.endpoints[next_index] + + # If all endpoints are in error state, use the one with earliest error expiry + self.last_used_index = next_index + return min(self.endpoints, key=lambda ep: ep.error_until) + + @asynccontextmanager + async def get_endpoint(self, max_wait_time: int = 10): + """Get the next available endpoint using a context manager""" + start_time = time.time() + endpoint = None + + try: + while True: + if time.time() - start_time > max_wait_time: + raise TimeoutError(f"Could not acquire an endpoint within {max_wait_time} seconds") + + async with self.lock: + # Get the next available endpoint using our selection strategy + endpoint = self._get_next_free_endpoint() + + # Mark it as busy + endpoint.busy = True + endpoint.last_used = time.time() + #logger.info(f"Using endpoint {endpoint.id} (busy: {endpoint.busy}, last used: {endpoint.last_used})") + break + + yield endpoint + + finally: + if endpoint: + async with self.lock: + endpoint.busy = False + endpoint.last_used = time.time() + # We don't need to put back into queue - our strategy now picks directly from the list + +class ChatRoom: + def __init__(self): + self.messages = [] + self.connected_clients = set() + self.max_history = 100 + + def add_message(self, message): + self.messages.append(message) + if len(self.messages) > self.max_history: + self.messages.pop(0) + + def get_recent_messages(self, limit=50): + return self.messages[-limit:] + +class VideoGenerationAPI: + def __init__(self): + self.inference_client = InferenceClient(token=HF_TOKEN) + self.hf_api = HfApi(token=HF_TOKEN) + self.endpoint_manager = EndpointManager() + self.active_requests: Dict[str, asyncio.Future] = {} + self.chat_rooms = defaultdict(ChatRoom) + self.video_events: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + self.event_history_limit = 50 + # Cache for user roles to avoid repeated API calls + self.user_role_cache: Dict[str, Dict[str, Any]] = {} + # Cache expiration time (10 minutes) + self.cache_expiration = 600 + + + def _add_event(self, video_id: str, event: Dict[str, Any]): + """Add an event to the video's history and maintain the size limit""" + events = self.video_events[video_id] + events.append(event) + if len(events) > self.event_history_limit: + events.pop(0) + + async def validate_user_token(self, token: str) -> UserRole: + """ + Validates a Hugging Face token and determines the user's role. 
+ + Returns one of: + - 'anon': Anonymous user (no token or invalid token) + - 'normal': Standard Hugging Face user + - 'pro': Hugging Face Pro user + - 'admin': Admin user (username in ADMIN_ACCOUNTS) + """ + # If no token is provided, the user is anonymous + if not token: + return 'anon' + + # Check if we have a cached result for this token + current_time = time.time() + if token in self.user_role_cache: + cached_data = self.user_role_cache[token] + # If the cache is still valid + if current_time - cached_data['timestamp'] < self.cache_expiration: + logger.info(f"Using cached user role: {cached_data['role']}") + return cached_data['role'] + + # No valid cache, need to check the token with the HF API + try: + # Use HF API to validate the token and get user info + logger.info("Validating Hugging Face token...") + + # Run in executor to avoid blocking the event loop + user_info = await asyncio.get_event_loop().run_in_executor( + None, + lambda: self.hf_api.whoami(token=token) + ) + + # Handle both object and dict response formats from whoami + username = user_info.get('name') if isinstance(user_info, dict) else getattr(user_info, 'name', None) + is_pro = user_info.get('is_pro') if isinstance(user_info, dict) else getattr(user_info, 'is_pro', False) + + if not username: + logger.error(f"Could not determine username from user_info: {user_info}") + return 'anon' + + logger.info(f"Token valid for user: {username}") + + # Determine the user role based on the information + user_role: UserRole + + # Check if the user is an admin + if username in ADMIN_ACCOUNTS: + user_role = 'admin' + # Check if the user has a pro account + elif is_pro: + user_role = 'pro' + else: + user_role = 'normal' + + # Cache the result + self.user_role_cache[token] = { + 'role': user_role, + 'timestamp': current_time, + 'username': username + } + + return user_role + + except Exception as e: + logger.error(f"Failed to validate Hugging Face token: {str(e)}") + # If validation fails, the user is treated as anonymous + return 'anon' + + async def download_video(self, url: str) -> bytes: + """Download video file from URL and return bytes""" + async with ClientSession() as session: + async with session.get(url) as response: + if response.status != 200: + raise Exception(f"Failed to download video: HTTP {response.status}") + return await response.read() + + async def search_video(self, query: str, attempt_count: int = 0) -> Optional[dict]: + """Generate a single search result using HF text generation""" + # Maximum number of attempts to generate a description without placeholder tags + max_attempts = 2 + current_attempt = attempt_count + # Use a random temperature between 0.68 and 0.72 to generate more diverse results + # and prevent duplicate results from successive calls with the same prompt + temperature = random.uniform(0.68, 0.72) + + while current_attempt <= max_attempts: + prompt = f"""# Instruction +Your response MUST be a YAML object containing a title and description, consistent with what we can find on a video sharing platform. +Format your YAML response with only those fields: "title" (a short string) and "description" (string caption of the scene). Do not add any other field. +In the description field, describe in a very synthetic way the visuals of the first shot (first scene), eg "