jbilcke-hf (HF Staff) committed verified commit 260ff53 (1 parent: 8130c8c)

Upload 76 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the rest.
Dockerfile ADDED
@@ -0,0 +1,45 @@
+FROM nvidia/cuda:12.0.1-runtime-ubuntu22.04
+
+# Set environment variables
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3 \
+    python3-pip \
+    python3-dev \
+    ffmpeg \
+    libsm6 \
+    libxext6 \
+    libxrender-dev \
+    libglib2.0-0 \
+    git \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+# Set working directory
+WORKDIR /app
+
+# Copy requirements first to leverage Docker caching
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip3 install --no-cache-dir -r requirements.txt
+RUN pip3 install aiohttp
+
+# Install additional required packages
+RUN pip3 install --no-cache-dir torch torchvision torchaudio
+
+# Copy application code
+COPY . .
+
+# Create assets directory if it doesn't exist
+RUN mkdir -p /app/assets
+
+# Expose the port used by the server
+EXPOSE 8080
+
+# Set entry command
+CMD ["python3", "server.py", "--host", "0.0.0.0", "--port", "8080"]
LICENSE ADDED
@@ -0,0 +1,22 @@
+MIT License
+
+Copyright (c) 2024 Eloi Alonso
+Copyright (c) 2025 Enigma Labs AI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md CHANGED
@@ -1,10 +1,170 @@
 ---
-title: Tikslop Gaming Multiverse
-emoji: 🏃
-colorFrom: purple
-colorTo: purple
 sdk: docker
-pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 ---
+title: Multiverse
+emoji: 🐟
+colorFrom: blue
+colorTo: blue
 sdk: docker
+app_file: server.py
+pinned: true
+short_description: AI Multiplayer World Model
+app_port: 8080
+disable_embedding: false
 ---
 
+# Multiverse: The First AI Multiplayer World Model
+
+🌐 [Enigma-AI website](https://enigma-labs.io/) - 📚 [Technical Blog](https://enigma-labs.io/) - [🤗 Model on Huggingface](https://huggingface.co/Enigma-AI/multiverse) - [🤗 Datasets on Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res) - 𝕏 [Multiverse Tweet](https://x.com/j0nathanj/status/1920516649511244258)
+
+<div align='center'>
+<b>Two human players driving cars in Multiverse</b>
+<br>
+<img alt="Cars in Multiverse" src="assets/demo.gif" width="400">
+</div>
+
+---
+
+## Installation
+```bash
+git clone https://github.com/EnigmaLabsAI/multiverse
+cd multiverse
+pip install -r requirements.txt
+```
+
+### Running the model
+
+```bash
+python src/play.py --compile
+```
+
+> Note: on Apple Silicon you must enable CPU fallback for the MPS backend: `PYTORCH_ENABLE_MPS_FALLBACK=1 python src/play.py`
+
+When running this command, you will be prompted with the controls. Press `enter` to start:
+![img.png](assets/img.png)
+
+Then the game will start:
+* To control the silver car at the top of the screen, use the arrow keys.
+* To control the blue car at the bottom, use the WASD keys.
+
+![img_2.png](assets/img_2.png)
+
+---
+
+## Training
+
+Multiverse comprises two models:
+* Denoiser - a world model that simulates the game
+* Upsampler - a model that takes frames from the denoiser and increases their resolution
+
+### Denoiser training
+
+#### 1. Download the dataset
+Download the Denoiser's training set from [🤗 Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res).
+
+#### 2. Process data for training
+Run the command:
+```bash
+python src/process_denoiser_files.py <folder_with_dataset_files_from_step_one> <folder_to_store_processed_data>
+```
+
+#### 3. Edit training configuration
+
+Edit [config/env/racing.yaml](config/env/racing.yaml) and set:
+- `path_data_low_res` to `<folder_to_store_processed_data>/low_res`
+- `path_data_full_res` to `<folder_to_store_processed_data>/full_res`
+
+Edit [config/trainer.yaml](config/trainer.yaml) to train the `denoiser`:
+```yaml
+train_model: denoiser
+```
+
+#### 4. Launch training run
+
+You can then launch a training run with `python src/main.py`.
+
85
+
86
+ ### Upsampler training
87
+
88
+ #### 1. Download the dataset
89
+ Download the Upsampler's training set from [🤗 Huggingface](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-full-res).
90
+
91
+ #### 2. Process data for training
92
+ Run the command:
93
+ ```bash
94
+ python src/process_upsampler_files.py <folder_with_dataset_files_from_step_one> <folder_to_store_processed_data>
95
+ ```
96
+
97
+ #### 3. Edit training configuration
98
+
99
+ Edit [config/env/racing.yaml](config/env/racing.yaml) and set:
100
+ - `path_data_low_res` to `<folder_to_store_processed_data>/low_res`
101
+ - `path_data_full_res` to `<folder_to_store_processed_data>/full_res`
102
+
103
+ Edit [config/training.yaml](config/trainer.yaml) to train the `denoiser`:
104
+ ```yaml
105
+ train_model: upsampler
106
+ ```
107
+
108
+ #### 4. Launch training run
109
+
110
+ You can then launch a training run with `python src/main.py`.
111
+
112
+
113
+ ---
114
+
+## Datasets
+
+1. We've collected over 4 hours of multiplayer (1v1) footage from Gran Turismo 4 at a resolution of 48x64 (per player): [🤗 Huggingface link](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-low-res).
+
+2. A sparse sampling of full-resolution, cropped frames is available for training the upsampler at a resolution of 350x530: [🤗 Huggingface link](https://huggingface.co/datasets/Enigma-AI/multiplayer-racing-full-res).
+
+The datasets contain a variety of situations: acceleration, braking, overtakes, crashes, and expert driving for both players.
+You can read about the data collection mechanism [here](https://enigma-labs.io/blog).
+
+Note: the full-resolution dataset is intended only for upsampler training and is not suitable for world model training.
+
+---
+
+## Outside resources
+
+- DIAMOND - https://github.com/eloialonso/diamond
+- AI-MarioKart64 - https://github.com/Dere-Wah/AI-MarioKart64
+
+---
+
+## Cloud Gaming Server
+
+This project now includes a WebSocket-based cloud gaming server that lets you play the game through a web browser.
+
+### Using Docker (Recommended for GPU Servers)
+
+The easiest way to deploy the cloud gaming server on a machine with an NVIDIA GPU is with Docker:
+
+```bash
+# Build the Docker image
+docker build -t ai-game-multiverse .
+
+# Run the container with GPU support
+docker run --gpus all -p 8080:8080 ai-game-multiverse
+```
+
+Then access the web interface at http://yourserver:8080
+
+### Features
+
+- Web-based interface accessible from any modern browser
+- Real-time streaming of AI-generated game frames
+- Keyboard and mouse controls
+- Multiple scene selection
+- WebSocket communication for low-latency interaction
+
+### Usage
+
+1. Access the web interface at http://yourserver:8080
+2. Click "Connect" to establish a WebSocket connection
+3. Select a scene from the dropdown
+4. Click "Start Stream" to begin streaming frames
+5. Use WASD keys for movement, Space for jump, Shift for attack
+6. Mouse controls the camera view (click on the game area to capture the mouse)
+
+A headless client can drive the same WebSocket protocol directly; see the sketch after this section.
+
+Note: The server requires an NVIDIA GPU for optimal performance with the AI models. Without a suitable GPU, it will fall back to serving simple placeholder frames.
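
For reference, the browser client speaks a small JSON protocol over the server's `/ws` endpoint. Below is a minimal headless Python client sketch for that protocol; the message fields (`action`, `requestId`, `fps`, `key`, `pressed`, `frameData`) are taken from `example/client.js` further down in this diff, while the host, save path, and frame count are illustrative. It assumes the third-party `websockets` package, which this repo does not pin.

```python
# Minimal headless client for the cloud gaming server (sketch, not the official client).
import asyncio
import base64
import json
import uuid

import websockets  # pip install websockets (assumption: not part of this repo's requirements)

SERVER = "ws://localhost:8080/ws"  # adjust host/port for your deployment

async def main() -> None:
    async with websockets.connect(SERVER) as ws:
        # Ask the server to start streaming at 16 FPS.
        await ws.send(json.dumps({
            "action": "start_stream",
            "requestId": uuid.uuid4().hex,
            "fps": 16,
        }))
        # Hold "forward" (the server expects action names, not raw key codes).
        await ws.send(json.dumps({
            "action": "keyboard_input",
            "requestId": uuid.uuid4().hex,
            "key": "forward",
            "pressed": True,
        }))
        # Save the first 10 streamed frames to disk as JPEGs.
        saved = 0
        while saved < 10:
            message = json.loads(await ws.recv())
            if message.get("action") == "frame" and message.get("frameData"):
                with open(f"frame_{saved:03d}.jpg", "wb") as f:
                    f.write(base64.b64decode(message["frameData"]))
                saved += 1

asyncio.run(main())
```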
config/agent/racing.yaml ADDED
@@ -0,0 +1,66 @@
+_target_: agent.AgentConfig
+
+denoiser:
+  _target_: models.diffusion.DenoiserConfig
+  sigma_data: 0.5
+  sigma_offset_noise: 0.1
+  noise_previous_obs: true
+  upsampling_factor: null
+  frame_sampling:
+    - count: 4
+      stride: 1
+    - count: 4
+      stride: 4
+  inner_model:
+    _target_: models.diffusion.InnerModelConfig
+    img_channels: 6
+    num_steps_conditioning: 8
+    cond_channels: 2048
+    depths:
+      - 2
+      - 2
+      - 2
+      - 2
+    channels:
+      - 128
+      - 256
+      - 512
+      - 1024
+    attn_depths:
+      - 0
+      - 0
+      - 1
+      - 1
+
+upsampler:
+  _target_: models.diffusion.DenoiserConfig
+  sigma_data: 0.5
+  sigma_offset_noise: 0.1
+  noise_previous_obs: false
+  upsampling_factor: 10
+  upsampling_frame_height: 350
+  upsampling_frame_width: 530
+  inner_model:
+    _target_: models.diffusion.InnerModelConfig
+    img_channels: 6
+    num_steps_conditioning: 0
+    cond_channels: 2048
+    depths:
+      - 2
+      - 2
+      - 2
+      - 2
+    channels:
+      - 64
+      - 64
+      - 128
+      - 256
+    attn_depths:
+      - 0
+      - 0
+      - 0
+      - 0
+
+rew_end_model: null
+
+actor_critic: null
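
The `_target_` keys above follow Hydra's instantiation convention: each node names a class, and its sibling keys become constructor arguments, recursively. A minimal sketch of how such a node becomes a live object (assuming `hydra-core` and `omegaconf` are installed and this repository is on `PYTHONPATH`, since the class paths come from it):

```python
# Sketch: turning a _target_ node from this YAML into a Python object via Hydra.
from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.load("config/agent/racing.yaml")

# Builds models.diffusion.DenoiserConfig(sigma_data=0.5, sigma_offset_noise=0.1, ...),
# recursively instantiating the nested inner_model node as well.
denoiser_cfg = instantiate(cfg.denoiser)
print(type(denoiser_cfg).__name__)  # -> DenoiserConfig
```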
config/env/racing.yaml ADDED
@@ -0,0 +1,7 @@
+train:
+  id: racing
+  size: [700, 530]
+  num_actions: 66
+  path_data_low_res: null
+  path_data_full_res: null
+  keymap: racing
config/trainer.yaml ADDED
@@ -0,0 +1,113 @@
+defaults:
+  - _self_
+  - env: racing
+  - agent: racing
+  - world_model_env: fast
+
+hydra:
+  job:
+    chdir: True
+
+wandb:
+  mode: offline
+  project: null
+  entity: null
+  name: null
+  group: null
+  tags: null
+
+initialization:
+  path_to_ckpt: null
+  load_denoiser: True
+  load_rew_end_model: True
+  load_actor_critic: True
+
+common:
+  devices: all  # int, list of int, cpu, or all
+  seed: null
+  resume: False  # do not modify, set by scripts/resume.sh only.
+
+checkpointing:
+  save_agent_every: 5
+  num_to_keep: 11  # number of checkpoints to keep, use null to disable
+
+collection:
+  train:
+    num_envs: 1
+    epsilon: 0.01
+    num_steps_total: 100000
+    first_epoch:
+      min: 5000
+      max: 10000  # null: no maximum
+      threshold_rew: 10
+    steps_per_epoch: 100
+  test:
+    num_envs: 1
+    num_episodes: 4
+    epsilon: 0.0
+    num_final_episodes: 100
+
+static_dataset:
+  path: ${env.path_data_low_res}
+  ignore_sample_weights: True
+
+training:
+  should: True
+  num_final_epochs: 600
+  cache_in_ram: False
+  num_workers_data_loaders: 1
+  model_free: False  # if True, turn off world_model training and RL in imagination
+  compile_wm: False
+
+evaluation:
+  should: True
+  every: 20
+
+train_model: denoiser
+
+denoiser:
+  training:
+    num_autoregressive_steps: 8
+    initial_num_consecutive_page_count: 1
+    num_consecutive_pages:
+      - epoch: 400
+        count: 10
+      - epoch: 500
+        count: 50
+    start_after_epochs: 0
+    steps_first_epoch: 10
+    steps_per_epoch: 20
+    sample_weights: null
+    batch_size: 30
+    grad_acc_steps: 2
+    lr_warmup_steps: 100
+    max_grad_norm: 10.0
+
+  optimizer:
+    lr: 1e-4
+    weight_decay: 1e-2
+    eps: 1e-8
+
+  sigma_distribution:  # log-normal distribution for sigma during training
+    _target_: models.diffusion.SigmaDistributionConfig
+    loc: -1.2
+    scale: 1.2
+    sigma_min: 2e-3
+    sigma_max: 20
+
+upsampler:
+  training:
+    num_autoregressive_steps: 1
+    initial_num_consecutive_page_count: 1
+    start_after_epochs: 0
+    steps_first_epoch: 20
+    steps_per_epoch: 20
+    sample_weights: null
+    batch_size: 4
+    grad_acc_steps: 2
+    lr_warmup_steps: 100
+    max_grad_norm: 10.0
+
+  optimizer: ${denoiser.optimizer}
+  sigma_distribution: ${denoiser.sigma_distribution}
+
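
The `${...}` values above are OmegaConf interpolations: `upsampler.optimizer` and `upsampler.sigma_distribution` resolve to the corresponding `denoiser` nodes on access, so both models share those settings without duplication. A minimal sketch (assuming `omegaconf` is installed and the file is loaded from the repo root):

```python
# Sketch: how the ${...} interpolations in trainer.yaml resolve.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/trainer.yaml")

# ${denoiser.optimizer} resolves lazily on access, so the upsampler
# inherits the denoiser's optimizer settings without duplicating them.
assert cfg.upsampler.optimizer.lr == cfg.denoiser.optimizer.lr  # 1e-4

# Hydra overrides at launch time rewrite these nodes before training, e.g.:
#   python src/main.py train_model=upsampler wandb.mode=online
```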
config/world_model_env/fast.yaml ADDED
@@ -0,0 +1,27 @@
+_target_: envs.WorldModelEnvConfig
+horizon: 1000
+num_batches_to_preload: 256
+diffusion_sampler_next_obs:
+  _target_: models.diffusion.DiffusionSamplerConfig
+  num_steps_denoising: 1
+  sigma_min: 2e-3
+  sigma_max: 5.0
+  rho: 7
+  order: 1  # 1: Euler, 2: Heun
+  s_churn: 0.0  # Amount of stochasticity
+  s_tmin: 0.0
+  s_tmax: ${eval:'float("inf")'}
+  s_noise: 1.0
+  s_cond: 0.005
+diffusion_sampler_upsampling:
+  _target_: models.diffusion.DiffusionSamplerConfig
+  num_steps_denoising: 1
+  sigma_min: 1
+  sigma_max: 5.0
+  rho: 7
+  order: 2  # 1: Euler, 2: Heun
+  s_churn: 10.0  # Amount of stochasticity
+  s_tmin: 1
+  s_tmax: 5
+  s_noise: 0.9
+  s_cond: 0
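
The `sigma_min`/`sigma_max`/`rho` triplet matches the noise-schedule parameterization of the EDM sampler (Karras et al., 2022) that DIAMOND-style samplers typically use; whether this repository applies exactly that formula is an assumption worth verifying against its diffusion code. A sketch of the schedule those values would imply:

```python
# Sketch: the Karras/EDM noise schedule implied by sigma_min, sigma_max, rho.
# Assumption: the repo's diffusion sampler follows the standard EDM formula.
import torch

def karras_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: float) -> torch.Tensor:
    """Interpolate from sigma_max down to sigma_min in rho-warped space."""
    ramp = torch.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# With num_steps_denoising: 1, the sampler takes a single Euler (order: 1)
# step starting from sigma_max.
print(karras_sigmas(10, 2e-3, 5.0, 7.0))
```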
example/Dockerfile ADDED
@@ -0,0 +1,59 @@
+FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+ENV PYTHONUNBUFFERED=1
+
+RUN apt-get update && apt-get install --no-install-recommends -y \
+    build-essential \
+    python3.11 \
+    python3-pip \
+    python3-dev \
+    git \
+    curl \
+    ffmpeg \
+    libglib2.0-0 \
+    libsm6 \
+    libxrender1 \
+    libxext6 \
+    ninja-build \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /code
+
+COPY ./requirements.txt /code/requirements.txt
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+# Switch to the "user" user
+USER user
+# Set home to the user's home directory
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+# Set Python path and environment variables
+ENV PYTHONPATH=$HOME/app \
+    PYTHONUNBUFFERED=1 \
+    DATA_ROOT=/tmp/data
+
+RUN echo "Installing requirements.txt"
+RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
+
+# Install NVIDIA Apex with CUDA and C++ extensions
+RUN cd $HOME && \
+    git clone https://github.com/NVIDIA/apex && \
+    cd apex && \
+    NVCC_APPEND_FLAGS="--threads 4" pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--parallel" --global-option="8" ./
+
+WORKDIR $HOME/app
+
+# Copy all files and set proper ownership
+COPY --chown=user . $HOME/app
+
+# Expose the port that server.py uses (8080)
+EXPOSE 8080
+
+ENV PORT=8080
+
+# Run the HF space launcher script which sets up the correct paths
+CMD ["python3", "run_hf_space.py"]
example/client.js ADDED
@@ -0,0 +1,603 @@
+// MatrixGame WebSocket Client
+
+// WebSocket connection
+let socket = null;
+let userId = null;
+let isStreaming = false;
+let lastFrameTime = 0;
+let frameCount = 0;
+let fpsUpdateInterval = null;
+
+// DOM Elements
+const connectBtn = document.getElementById('connect-btn');
+const startStreamBtn = document.getElementById('start-stream-btn');
+const stopStreamBtn = document.getElementById('stop-stream-btn');
+const sceneSelect = document.getElementById('scene-select');
+const gameCanvas = document.getElementById('game-canvas');
+const connectionLog = document.getElementById('connection-log');
+const mousePosition = document.getElementById('mouse-position');
+const fpsCounter = document.getElementById('fps-counter');
+const mouseTrackingArea = document.getElementById('mouse-tracking-area');
+
+// Pointer Lock API support check
+const pointerLockSupported = 'pointerLockElement' in document ||
+    'mozPointerLockElement' in document ||
+    'webkitPointerLockElement' in document;
+
+// Keyboard DOM elements
+const keyElements = {
+    'w': document.getElementById('key-w'),
+    'a': document.getElementById('key-a'),
+    's': document.getElementById('key-s'),
+    'd': document.getElementById('key-d'),
+    'space': document.getElementById('key-space'),
+    'shift': document.getElementById('key-shift')
+};
+
+// Key mapping to action names
+const keyToAction = {
+    'w': 'forward',
+    'arrowup': 'forward',
+    'a': 'left',
+    'arrowleft': 'left',
+    's': 'back',
+    'arrowdown': 'back',
+    'd': 'right',
+    'arrowright': 'right',
+    ' ': 'jump',
+    'shift': 'attack'
+};
+
+// Key state tracking
+const keyState = {
+    'forward': false,
+    'back': false,
+    'left': false,
+    'right': false,
+    'jump': false,
+    'attack': false
+};
+
+// Mouse state
+const mouseState = {
+    x: 0,
+    y: 0,
+    captured: false
+};
+
+// Test server connectivity before establishing WebSocket
+async function testServerConnectivity() {
+    try {
+        // Get base path by extracting path from the script tag's src attribute
+        let basePath = '';
+        const scriptTags = document.getElementsByTagName('script');
+        for (const script of scriptTags) {
+            if (script.src.includes('client.js')) {
+                const url = new URL(script.src);
+                basePath = url.pathname.replace('/assets/client.js', '');
+                break;
+            }
+        }
+
+        // Try to fetch the debug endpoint to see if the server is accessible
+        const response = await fetch(`${window.location.protocol}//${window.location.host}${basePath}/api/debug`);
+        if (!response.ok) {
+            throw new Error(`Server returned ${response.status}`);
+        }
+
+        const debugInfo = await response.json();
+        logMessage(`Server connection test successful! Server time: ${new Date(debugInfo.server_time * 1000).toLocaleTimeString()}`);
+
+        // Log available routes from server
+        if (debugInfo.all_routes && debugInfo.all_routes.length > 0) {
+            logMessage(`Available routes: ${debugInfo.all_routes.join(', ')}`);
+        }
+
+        // Return the debug info for connection setup
+        return debugInfo;
+    } catch (error) {
+        logMessage(`Server connection test failed: ${error.message}`);
+        return null;
+    }
+}
+
+// Connect to WebSocket server
+async function connectWebSocket() {
+    // First test connectivity to the server
+    logMessage('Testing server connectivity...');
+    const debugInfo = await testServerConnectivity();
+
+    // Use secure WebSocket (wss://) if the page is loaded over HTTPS
+    const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
+
+    // Get base path by extracting path from the script tag's src attribute
+    let basePath = '';
+    if (debugInfo && debugInfo.base_path) {
+        // Use base path from server if available
+        basePath = debugInfo.base_path;
+        logMessage(`Using server-provided base path: ${basePath}`);
+    } else {
+        const scriptTags = document.getElementsByTagName('script');
+        for (const script of scriptTags) {
+            if (script.src.includes('client.js')) {
+                const url = new URL(script.src);
+                basePath = url.pathname.replace('/assets/client.js', '');
+                break;
+            }
+        }
+    }
+
+    // Try both with and without base path for WebSocket connection
+    let serverUrl = `${protocol}//${window.location.hostname}${window.location.port ? ':' + window.location.port : ''}${basePath}/ws`;
+    logMessage(`Attempting to connect to WebSocket at ${serverUrl}...`);
+
+    // For Hugging Face Spaces, try the direct /ws path if the base path doesn't work
+    const fallbackUrl = `${protocol}//${window.location.hostname}${window.location.port ? ':' + window.location.port : ''}/ws`;
+
+    try {
+        socket = new WebSocket(serverUrl);
+        setupWebSocketHandlers();
+
+        // Set a timeout to try the fallback URL if the first one doesn't connect
+        setTimeout(() => {
+            if (socket.readyState !== WebSocket.OPEN && socket.readyState !== WebSocket.CONNECTING) {
+                logMessage(`Connection to ${serverUrl} failed. Trying fallback URL: ${fallbackUrl}`);
+                socket = new WebSocket(fallbackUrl);
+                setupWebSocketHandlers();
+            }
+        }, 3000);
+    } catch (error) {
+        logMessage(`Error connecting to WebSocket: ${error.message}`);
+        resetUI();
+    }
+}
+
+// Set up WebSocket event handlers
+function setupWebSocketHandlers() {
+    socket.onopen = () => {
+        logMessage('WebSocket connection established');
+        connectBtn.textContent = 'Disconnect';
+        startStreamBtn.disabled = false;
+        sceneSelect.disabled = false;
+    };
+
+    socket.onmessage = (event) => {
+        const message = JSON.parse(event.data);
+
+        switch (message.action) {
+            case 'welcome':
+                userId = message.userId;
+                logMessage(`Connected with user ID: ${userId}`);
+
+                // Update scene options if server provides them
+                if (message.scenes && Array.isArray(message.scenes)) {
+                    sceneSelect.innerHTML = '';
+                    message.scenes.forEach(scene => {
+                        const option = document.createElement('option');
+                        option.value = scene;
+                        option.textContent = scene.charAt(0).toUpperCase() + scene.slice(1);
+                        sceneSelect.appendChild(option);
+                    });
+                }
+                break;
+
+            case 'frame':
+                // Process incoming frame
+                processFrame(message);
+                break;
+
+            case 'start_stream':
+                if (message.success) {
+                    isStreaming = true;
+                    startStreamBtn.disabled = true;
+                    stopStreamBtn.disabled = false;
+                    logMessage(`Streaming started: ${message.message}`);
+
+                    // Start FPS counter
+                    startFpsCounter();
+                } else {
+                    logMessage(`Error starting stream: ${message.error}`);
+                }
+                break;
+
+            case 'stop_stream':
+                if (message.success) {
+                    isStreaming = false;
+                    startStreamBtn.disabled = false;
+                    stopStreamBtn.disabled = true;
+                    logMessage('Streaming stopped');
+
+                    // Stop FPS counter
+                    stopFpsCounter();
+                } else {
+                    logMessage(`Error stopping stream: ${message.error}`);
+                }
+                break;
+
+            case 'pong':
+                // Server responded to ping
+                break;
+
+            case 'change_scene':
+                if (message.success) {
+                    logMessage(`Scene changed to ${message.scene}`);
+                } else {
+                    logMessage(`Error changing scene: ${message.error}`);
+                }
+                break;
+
+            default:
+                logMessage(`Received message: ${JSON.stringify(message)}`);
+        }
+    };
+
+    socket.onclose = (event) => {
+        logMessage(`WebSocket connection closed (code: ${event.code}, reason: ${event.reason || 'none given'})`);
+        resetUI();
+    };
+
+    socket.onerror = (error) => {
+        logMessage(`WebSocket error. This is often caused by CORS issues or the server being inaccessible.`);
+        console.error('WebSocket error:', error);
+        resetUI();
+    };
+}
+
+// Disconnect from WebSocket server
+function disconnectWebSocket() {
+    if (socket && socket.readyState === WebSocket.OPEN) {
+        // Stop streaming if active
+        if (isStreaming) {
+            sendStopStream();
+        }
+
+        // Close the socket
+        socket.close();
+        logMessage('Disconnected from server');
+    }
+}
+
+// Start streaming frames
+function sendStartStream() {
+    if (socket && socket.readyState === WebSocket.OPEN) {
+        socket.send(JSON.stringify({
+            action: 'start_stream',
+            requestId: generateRequestId(),
+            fps: 16 // Default FPS
+        }));
+    }
+}
+
+// Stop streaming frames
+function sendStopStream() {
+    if (socket && socket.readyState === WebSocket.OPEN) {
+        socket.send(JSON.stringify({
+            action: 'stop_stream',
+            requestId: generateRequestId()
+        }));
+    }
+}
+
+// Send keyboard input to server
+function sendKeyboardInput(key, pressed) {
+    if (socket && socket.readyState === WebSocket.OPEN) {
+        socket.send(JSON.stringify({
+            action: 'keyboard_input',
+            requestId: generateRequestId(),
+            key: key,
+            pressed: pressed
+        }));
+    }
+}
+
+// Send mouse input to server
+function sendMouseInput(x, y) {
+    if (socket && socket.readyState === WebSocket.OPEN && isStreaming) {
+        socket.send(JSON.stringify({
+            action: 'mouse_input',
+            requestId: generateRequestId(),
+            x: x,
+            y: y
+        }));
+    }
+}
+
+// Change scene
+function sendChangeScene(scene) {
+    if (socket && socket.readyState === WebSocket.OPEN) {
+        socket.send(JSON.stringify({
+            action: 'change_scene',
+            requestId: generateRequestId(),
+            scene: scene
+        }));
+    }
+}
+
+// Process incoming frame
+function processFrame(message) {
+    // Update FPS calculation
+    const now = performance.now();
+    if (lastFrameTime > 0) {
+        frameCount++;
+    }
+    lastFrameTime = now;
+
+    // Update the canvas with the new frame
+    if (message.frameData) {
+        gameCanvas.src = `data:image/jpeg;base64,${message.frameData}`;
+    }
+}
+
+// Generate a random request ID
+function generateRequestId() {
+    return Math.random().toString(36).substring(2, 15);
+}
+
+// Log message to the connection info panel
+function logMessage(message) {
+    const logEntry = document.createElement('div');
+    logEntry.className = 'log-entry';
+
+    const timestamp = new Date().toLocaleTimeString();
+    logEntry.textContent = `[${timestamp}] ${message}`;
+
+    connectionLog.appendChild(logEntry);
+    connectionLog.scrollTop = connectionLog.scrollHeight;
+
+    // Limit number of log entries
+    while (connectionLog.children.length > 100) {
+        connectionLog.removeChild(connectionLog.firstChild);
+    }
+}
+
+// Start FPS counter updates
+function startFpsCounter() {
+    frameCount = 0;
+    lastFrameTime = 0;
+
+    // Update FPS display every second
+    fpsUpdateInterval = setInterval(() => {
+        fpsCounter.textContent = `FPS: ${frameCount}`;
+        frameCount = 0;
+    }, 1000);
+}
+
+// Stop FPS counter updates
+function stopFpsCounter() {
+    if (fpsUpdateInterval) {
+        clearInterval(fpsUpdateInterval);
+        fpsUpdateInterval = null;
+    }
+    fpsCounter.textContent = 'FPS: 0';
+}
+
+// Reset UI to initial state
+function resetUI() {
+    connectBtn.textContent = 'Connect';
+    startStreamBtn.disabled = true;
+    stopStreamBtn.disabled = true;
+    sceneSelect.disabled = true;
+
+    // Reset key indicators
+    for (const key in keyElements) {
+        keyElements[key].classList.remove('active');
+    }
+
+    // Stop FPS counter
+    stopFpsCounter();
+
+    // Reset streaming state
+    isStreaming = false;
+}
+
+// Event Listeners
+connectBtn.addEventListener('click', () => {
+    if (socket && socket.readyState === WebSocket.OPEN) {
+        disconnectWebSocket();
+    } else {
+        connectWebSocket();
+    }
+});
+
+startStreamBtn.addEventListener('click', sendStartStream);
+stopStreamBtn.addEventListener('click', sendStopStream);
+
+sceneSelect.addEventListener('change', () => {
+    sendChangeScene(sceneSelect.value);
+});
+
+// Keyboard event listeners
+document.addEventListener('keydown', (event) => {
+    const key = event.key.toLowerCase();
+
+    // Map key to action
+    let action = keyToAction[key];
+    if (!action && key === ' ') {
+        action = keyToAction[' ']; // Handle spacebar
+    }
+
+    if (action && !keyState[action]) {
+        keyState[action] = true;
+
+        // Update visual indicator
+        const keyElement = keyElements[key] ||
+            (key === ' ' ? keyElements['space'] : null) ||
+            (key === 'shift' ? keyElements['shift'] : null);
+
+        if (keyElement) {
+            keyElement.classList.add('active');
+        }
+
+        // Send to server
+        sendKeyboardInput(action, true);
+    }
+
+    // Prevent default actions for game controls
+    if (Object.keys(keyToAction).includes(key) || key === ' ') {
+        event.preventDefault();
+    }
+});
+
+document.addEventListener('keyup', (event) => {
+    const key = event.key.toLowerCase();
+
+    // Map key to action
+    let action = keyToAction[key];
+    if (!action && key === ' ') {
+        action = keyToAction[' ']; // Handle spacebar
+    }
+
+    if (action && keyState[action]) {
+        keyState[action] = false;
+
+        // Update visual indicator
+        const keyElement = keyElements[key] ||
+            (key === ' ' ? keyElements['space'] : null) ||
+            (key === 'shift' ? keyElements['shift'] : null);
+
+        if (keyElement) {
+            keyElement.classList.remove('active');
+        }
+
+        // Send to server
+        sendKeyboardInput(action, false);
+    }
+});
+
+// Mouse capture functions
+function requestPointerLock() {
+    if (!mouseState.captured && pointerLockSupported) {
+        mouseTrackingArea.requestPointerLock = mouseTrackingArea.requestPointerLock ||
+            mouseTrackingArea.mozRequestPointerLock ||
+            mouseTrackingArea.webkitRequestPointerLock;
+        mouseTrackingArea.requestPointerLock();
+        logMessage('Mouse captured. Press ESC to release.');
+    }
+}
+
+function exitPointerLock() {
+    if (mouseState.captured) {
+        document.exitPointerLock = document.exitPointerLock ||
+            document.mozExitPointerLock ||
+            document.webkitExitPointerLock;
+        document.exitPointerLock();
+        logMessage('Mouse released.');
+    }
+}
+
+// Handle pointer lock change events
+document.addEventListener('pointerlockchange', pointerLockChangeHandler);
+document.addEventListener('mozpointerlockchange', pointerLockChangeHandler);
+document.addEventListener('webkitpointerlockchange', pointerLockChangeHandler);
+
+function pointerLockChangeHandler() {
+    if (document.pointerLockElement === mouseTrackingArea ||
+        document.mozPointerLockElement === mouseTrackingArea ||
+        document.webkitPointerLockElement === mouseTrackingArea) {
+        // Pointer is locked, enable mouse movement tracking
+        mouseState.captured = true;
+        document.addEventListener('mousemove', handleMouseMovement);
+    } else {
+        // Pointer is unlocked, disable mouse movement tracking
+        mouseState.captured = false;
+        document.removeEventListener('mousemove', handleMouseMovement);
+        // Reset mouse state
+        mouseState.x = 0;
+        mouseState.y = 0;
+        mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
+        throttledSendMouseInput();
+    }
+}
+
+// Mouse tracking with pointer lock
+function handleMouseMovement(event) {
+    if (mouseState.captured) {
+        // Use movement for mouse look when captured
+        const sensitivity = 0.005; // Adjust sensitivity
+        mouseState.x += event.movementX * sensitivity;
+        mouseState.y -= event.movementY * sensitivity; // Invert Y for intuitive camera control
+
+        // Clamp values
+        mouseState.x = Math.max(-1, Math.min(1, mouseState.x));
+        mouseState.y = Math.max(-1, Math.min(1, mouseState.y));
+
+        // Update display
+        mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
+
+        // Send to server (throttled)
+        throttledSendMouseInput();
+    }
+}
+
+// Mouse click to capture
+mouseTrackingArea.addEventListener('click', () => {
+    if (!mouseState.captured && isStreaming) {
+        requestPointerLock();
+    }
+});
+
+// Standard mouse tracking for when pointer is not locked
+mouseTrackingArea.addEventListener('mousemove', (event) => {
+    if (!mouseState.captured) {
+        // Calculate normalized coordinates relative to the center of the tracking area
+        const rect = mouseTrackingArea.getBoundingClientRect();
+        const centerX = rect.width / 2;
+        const centerY = rect.height / 2;
+
+        // Calculate relative position from center (-1 to 1)
+        const relX = (event.clientX - rect.left - centerX) / centerX;
+        const relY = (event.clientY - rect.top - centerY) / centerY;
+
+        // Scale down for smoother movement (similar to conditions.py)
+        const scaleFactor = 0.05;
+        mouseState.x = relX * scaleFactor;
+        mouseState.y = -relY * scaleFactor; // Invert Y for intuitive camera control
+
+        // Update display
+        mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
+
+        // Send to server (throttled)
+        throttledSendMouseInput();
+    }
+});
+
+// Throttle mouse movement to avoid flooding the server
+const throttledSendMouseInput = (() => {
+    let lastSentTime = 0;
+    const interval = 50; // milliseconds
+
+    return () => {
+        const now = performance.now();
+        if (now - lastSentTime >= interval) {
+            sendMouseInput(mouseState.x, mouseState.y);
+            lastSentTime = now;
+        }
+    };
+})();
+
+// Toggle panel collapse/expand
+function togglePanel(panelId) {
+    const panel = document.getElementById(panelId);
+    const button = panel.querySelector('.toggle-button');
+
+    if (panel.classList.contains('collapsed')) {
+        // Expand the panel
+        panel.classList.remove('collapsed');
+        button.textContent = '−'; // Minus sign
+    } else {
+        // Collapse the panel
+        panel.classList.add('collapsed');
+        button.textContent = '+'; // Plus sign
+    }
+}
+
+// Initialize the UI
+resetUI();
+
+// Make panel headers clickable
+document.querySelectorAll('.panel-header').forEach(header => {
+    header.addEventListener('click', () => {
+        const panelId = header.parentElement.id;
+        togglePanel(panelId);
+    });
+});
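
For orientation, the server side of this protocol could be handled as follows. This is a minimal aiohttp sketch (aiohttp is pinned in example/requirements.txt); the handler body is illustrative only, and the real implementation lives in example/server.py below.

```python
# Sketch of the server side of the protocol client.js speaks (not the actual server).
import json
import uuid

from aiohttp import web, WSMsgType

async def ws_handler(request: web.Request) -> web.WebSocketResponse:
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    # Matches the 'welcome' message client.js consumes on connect.
    await ws.send_json({
        "action": "welcome",
        "userId": uuid.uuid4().hex,
        "scenes": ["forest", "desert", "beach"],
    })

    async for msg in ws:
        if msg.type != WSMsgType.TEXT:
            continue
        data = json.loads(msg.data)
        if data.get("action") == "start_stream":
            # Acknowledge; a real server would also begin pushing 'frame' messages.
            await ws.send_json({
                "action": "start_stream",
                "success": True,
                "message": f"streaming at {data.get('fps', 16)} fps",
            })
        elif data.get("action") == "keyboard_input":
            pass  # update per-connection input state here
    return ws

app = web.Application()
app.add_routes([web.get("/ws", ws_handler)])
# web.run_app(app, port=8080)
```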
example/engine.py ADDED
@@ -0,0 +1,332 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+MatrixGame Engine
+
+This module handles the core rendering and model inference for the MatrixGame project.
+"""
+
+import os
+import logging
+import argparse
+import time
+import torch
+import numpy as np
+from PIL import Image
+import cv2
+from einops import rearrange
+from diffusers.utils import load_image
+from diffusers.video_processor import VideoProcessor
+from typing import Dict, List, Tuple, Any, Optional, Union
+
+# MatrixGame specific imports
+from matrixgame.sample.pipeline_matrixgame import MatrixGameVideoPipeline
+from matrixgame.model_variants import get_dit
+from matrixgame.vae_variants import get_vae
+from matrixgame.encoder_variants import get_text_enc
+from matrixgame.model_variants.matrixgame_dit_src import MGVideoDiffusionTransformerI2V
+from matrixgame.sample.flow_matching_scheduler_matrixgame import FlowMatchDiscreteScheduler
+from teacache_forward import teacache_forward
+
+# Import utility functions
+from utils import (
+    visualize_controls,
+    frame_to_jpeg,
+    load_scene_frames,
+    logger
+)
+
+class MatrixGameEngine:
+    """
+    Core engine for MatrixGame model inference and frame generation.
+    """
+    def __init__(self, args: Optional[argparse.Namespace] = None):
+        """
+        Initialize the MatrixGame engine with configuration parameters.
+
+        Args:
+            args: Optional parsed command line arguments for model configuration
+        """
+        # Set default parameters if args not provided
+        self.frame_width = getattr(args, 'frame_width', 640)
+        self.frame_height = getattr(args, 'frame_height', 360)
+        self.fps = getattr(args, 'fps', 16)
+        self.inference_steps = getattr(args, 'inference_steps', 20)
+        self.guidance_scale = getattr(args, 'guidance_scale', 6.0)
+        self.num_pre_frames = getattr(args, 'num_pre_frames', 3)
+
+        # Initialize state
+        self.frame_count = 0
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.weight_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+        # Model paths from environment or args
+        self.vae_path = os.environ.get("VAE_PATH", "./models/matrixgame/vae/")
+        self.dit_path = os.environ.get("DIT_PATH", "./models/matrixgame/dit/")
+        self.textenc_path = os.environ.get("TEXTENC_PATH", "./models/matrixgame")
+
+        # Cache scene initial frames
+        self.scenes = {
+            'forest': load_scene_frames('forest', self.frame_width, self.frame_height),
+            'desert': load_scene_frames('desert', self.frame_width, self.frame_height),
+            'beach': load_scene_frames('beach', self.frame_width, self.frame_height),
+            'hills': load_scene_frames('hills', self.frame_width, self.frame_height),
+            'river': load_scene_frames('river', self.frame_width, self.frame_height),
+            'icy': load_scene_frames('icy', self.frame_width, self.frame_height),
+            'mushroom': load_scene_frames('mushroom', self.frame_width, self.frame_height),
+            'plain': load_scene_frames('plain', self.frame_width, self.frame_height)
+        }
+
+        # Cache initial images for model input
+        self.scene_initial_images = {}
+
+        # Initialize MatrixGame pipeline
+        self.model_loaded = False
+        if torch.cuda.is_available():
+            try:
+                self._init_models()
+                self.model_loaded = True
+                logger.info("MatrixGame models loaded successfully")
+            except Exception as e:
+                logger.error(f"Failed to initialize MatrixGame models: {str(e)}")
+                logger.info("Falling back to frame cycling mode")
+        else:
+            logger.warning("CUDA not available. Using frame cycling mode only.")
+
+    def _init_models(self):
+        """Initialize MatrixGame models (VAE, text encoder, transformer)"""
+        # Initialize flow matching scheduler
+        self.scheduler = FlowMatchDiscreteScheduler(
+            shift=15.0,
+            reverse=True,
+            solver="euler"
+        )
+
+        # Initialize VAE
+        try:
+            self.vae = get_vae("matrixgame", self.vae_path, self.weight_dtype)
+            self.vae.requires_grad_(False)
+            self.vae.eval()
+            self.vae.enable_tiling()
+            logger.info("VAE model loaded successfully")
+        except Exception as e:
+            logger.error(f"Error loading VAE model: {str(e)}")
+            raise
+
+        # Initialize DIT (Transformer)
+        try:
+            dit = MGVideoDiffusionTransformerI2V.from_pretrained(self.dit_path)
+            dit.requires_grad_(False)
+            dit.eval()
+            logger.info("DIT model loaded successfully")
+        except Exception as e:
+            logger.error(f"Error loading DIT model: {str(e)}")
+            raise
+
+        # Initialize text encoder
+        try:
+            self.text_enc = get_text_enc('matrixgame', self.textenc_path, weight_dtype=self.weight_dtype, i2v_type='refiner')
+            logger.info("Text encoder loaded successfully")
+        except Exception as e:
+            logger.error(f"Error loading text encoder: {str(e)}")
+            raise
+
+        # Initialize pipeline
+        try:
+            self.pipeline = MatrixGameVideoPipeline(
+                vae=self.vae.vae,
+                text_encoder=self.text_enc,
+                transformer=dit,
+                scheduler=self.scheduler,
+            ).to(self.weight_dtype).to(self.device)
+            logger.info("Pipeline initialized successfully")
+        except Exception as e:
+            logger.error(f"Error initializing pipeline: {str(e)}")
+            raise
+
+        # Configure teacache for the transformer
+        self.pipeline.transformer.__class__.enable_teacache = True
+        self.pipeline.transformer.__class__.cnt = 0
+        self.pipeline.transformer.__class__.num_steps = self.inference_steps
+        self.pipeline.transformer.__class__.accumulated_rel_l1_distance = 0
+        self.pipeline.transformer.__class__.rel_l1_thresh = 0.075
+        self.pipeline.transformer.__class__.previous_modulated_input = None
+        self.pipeline.transformer.__class__.previous_residual = None
+        self.pipeline.transformer.__class__.forward = teacache_forward
+
+        # Preprocess initial images for all scenes
+        for scene_name, frames in self.scenes.items():
+            if frames:
+                # Use first frame as initial image
+                self.scene_initial_images[scene_name] = self._preprocess_image(frames[0])
+
+    def _preprocess_image(self, image_array: np.ndarray) -> torch.Tensor:
+        """
+        Preprocess an image for the model.
+
+        Args:
+            image_array: Input image as numpy array
+
+        Returns:
+            torch.Tensor: Preprocessed image tensor
+        """
+        # Convert numpy array to PIL Image if needed
+        if isinstance(image_array, np.ndarray):
+            image = Image.fromarray(image_array)
+        else:
+            image = image_array
+
+        # Preprocess for VAE
+        vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, 'vae') else 8
+        video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor)
+        initial_image = video_processor.preprocess(image, height=self.frame_height, width=self.frame_width)
+
+        # Add past frames for stability (use same frame repeated)
+        past_frames = initial_image.repeat(self.num_pre_frames, 1, 1, 1)
+        initial_image = torch.cat([initial_image, past_frames], dim=0)
+
+        return initial_image
+
+    def generate_frame(self, scene_name: str, keyboard_condition: Optional[List] = None,
+                       mouse_condition: Optional[List] = None) -> bytes:
+        """
+        Generate the next frame based on current conditions using MatrixGame model.
+
+        Args:
+            scene_name: Name of the current scene
+            keyboard_condition: Keyboard input state
+            mouse_condition: Mouse input state
+
+        Returns:
+            bytes: JPEG bytes of the frame
+        """
+        # Check if model is loaded
+        if not self.model_loaded or not torch.cuda.is_available():
+            # Fall back to frame cycling for demo mode or if models failed to load
+            return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
+        else:
+            # Use MatrixGame model for frame generation
+            try:
+                # Get initial image for this scene
+                initial_image = self.scene_initial_images.get(scene_name)
+                if initial_image is None:
+                    # Use forest as default if we don't have an initial image for this scene
+                    initial_image = self.scene_initial_images.get('forest')
+                    if initial_image is None:
+                        # If we still don't have an initial image, fall back to frame cycling
+                        logger.error(f"No initial image available for scene {scene_name}")
+                        return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
+
+                # Prepare input tensors (move to device and format correctly)
+                if keyboard_condition is None:
+                    keyboard_condition = [[0, 0, 0, 0, 0, 0]]
+                if mouse_condition is None:
+                    mouse_condition = [[0, 0]]
+
+                # Convert conditions to tensors
+                keyboard_tensor = torch.tensor(keyboard_condition, dtype=torch.float32)
+                mouse_tensor = torch.tensor(mouse_condition, dtype=torch.float32)
+
+                # Move to device and convert to correct dtype
+                keyboard_tensor = keyboard_tensor.to(self.weight_dtype).to(self.device)
+                mouse_tensor = mouse_tensor.to(self.weight_dtype).to(self.device)
+
+                # Get the first frame from the scene for semantic conditioning
+                scene_frames = self.scenes.get(scene_name, self.scenes['forest'])
+                if not scene_frames:
+                    return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
+
+                semantic_image = Image.fromarray(scene_frames[0])
+
+                # Get PIL image version of the frame for visualization
+                for scene_frame in scene_frames:
+                    if isinstance(scene_frame, np.ndarray):
+                        semantic_image = Image.fromarray(scene_frame)
+                        break
+
+                # Generate a single frame with the model
+                # Use fewer inference steps for interactive frame generation
+                with torch.no_grad():
+                    # Generate a short video (we'll just use the first frame)
+                    # We're using a short length (3 frames) for real-time performance
+                    video = self.pipeline(
+                        height=self.frame_height,
+                        width=self.frame_width,
+                        video_length=3,  # Generate a very short video for speed
+                        mouse_condition=mouse_tensor,
+                        keyboard_condition=keyboard_tensor,
+                        initial_image=initial_image,
+                        num_inference_steps=self.inference_steps,
+                        guidance_scale=self.guidance_scale,
+                        embedded_guidance_scale=None,
+                        data_type="video",
+                        vae_ver='884-16c-hy',
+                        enable_tiling=True,
+                        generator=torch.Generator(device=self.device).manual_seed(42),
+                        i2v_type='refiner',
+                        semantic_images=semantic_image
+                    ).videos[0]
+
+                # Convert video tensor to numpy array (use first frame)
+                video_frame = video[0].permute(1, 2, 0).cpu().numpy()
+                video_frame = (video_frame * 255).astype(np.uint8)
+                frame = video_frame
+
+                # Increment frame counter
+                self.frame_count += 1
+
+            except Exception as e:
+                logger.error(f"Error generating frame with MatrixGame model: {str(e)}")
+                # Fall back to cycling demo frames if model generation fails
+                return self._fallback_frame(scene_name, keyboard_condition, mouse_condition)
+
+            # Add visualization of input controls
+            frame = visualize_controls(
+                frame, keyboard_condition, mouse_condition,
+                self.frame_width, self.frame_height
+            )
+
+            # Convert frame to JPEG
+            return frame_to_jpeg(frame, self.frame_height, self.frame_width)
+
+    def _fallback_frame(self, scene_name: str, keyboard_condition: Optional[List] = None,
+                        mouse_condition: Optional[List] = None) -> bytes:
+        """
+        Generate a fallback frame when model generation fails.
+
+        Args:
+            scene_name: Name of the current scene
+            keyboard_condition: Keyboard input state
+            mouse_condition: Mouse input state
+
+        Returns:
+            bytes: JPEG bytes of the frame
+        """
+        scene_frames = self.scenes.get(scene_name, self.scenes['forest'])
+        frame_idx = self.frame_count % len(scene_frames)
+        frame = scene_frames[frame_idx].copy()
+        self.frame_count += 1
+
+        # Add fallback mode indicator
+        cv2.putText(frame, "Fallback mode",
+                    (10, self.frame_height - 20),
+                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
+
+        # Add visualization of input controls
+        frame = visualize_controls(
+            frame, keyboard_condition, mouse_condition,
+            self.frame_width, self.frame_height
+        )
+
+        # Convert frame to JPEG
+        return frame_to_jpeg(frame, self.frame_height, self.frame_width)
+
+    def get_valid_scenes(self) -> List[str]:
+        """
+        Get a list of valid scene names.
+
+        Returns:
+            List[str]: List of valid scene names
+        """
+        return list(self.scenes.keys())
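
A quick usage sketch for the engine above, using only names defined in this file. It assumes you run from the `example/` directory (so `engine` is importable), and the ordering of the 6-element keyboard vector ([forward, back, left, right, jump, attack]) is an assumption inferred from the key mapping in example/client.js.

```python
# Sketch: driving MatrixGameEngine directly, outside the WebSocket server.
# Assumption: keyboard vector order is [forward, back, left, right, jump, attack].
from engine import MatrixGameEngine

engine = MatrixGameEngine()  # falls back to frame cycling without CUDA/model weights

keyboard = [[1, 0, 0, 0, 0, 0]]  # hold "forward"
mouse = [[0.02, 0.0]]            # slight camera pan

jpeg_bytes = engine.generate_frame("forest", keyboard_condition=keyboard,
                                   mouse_condition=mouse)
with open("frame.jpg", "wb") as f:
    f.write(jpeg_bytes)

print(engine.get_valid_scenes())
```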
example/index.html ADDED
@@ -0,0 +1,329 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>MatrixGame Client</title>
+    <style>
+        body {
+            font-family: Arial, sans-serif;
+            margin: 0;
+            padding: 0;
+            background-color: #121212;
+            color: #e0e0e0;
+            display: flex;
+            flex-direction: column;
+            align-items: center;
+            user-select: none; /* Disable text selection */
+            -webkit-user-select: none;
+            -moz-user-select: none;
+            -ms-user-select: none;
+            overflow-x: hidden;
+        }
+
+        .container {
+            width: 100%;
+            max-width: 100%;
+            display: flex;
+            flex-direction: column;
+            align-items: center;
+        }
+
+        .game-area {
+            display: flex;
+            flex-direction: column;
+            align-items: center;
+            width: 100%;
+            max-height: 85vh;
+            margin: 0;
+            position: relative;
+        }
+
+        #mouse-tracking-area {
+            position: relative;
+            width: 100%;
+            height: auto;
+            cursor: pointer; /* Show cursor as pointer to encourage clicks */
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            max-height: 85vh;
+        }
+
+        #game-canvas {
+            width: 100%;
+            height: auto;
+            max-height: 85vh;
+            object-fit: contain;
+            background-color: #000;
+            pointer-events: none; /* Prevent drag on the image */
+            -webkit-user-drag: none;
+            -khtml-user-drag: none;
+            -moz-user-drag: none;
+            -o-user-drag: none;
+            user-drag: none;
+        }
+
+        .controls {
+            display: flex;
+            justify-content: space-between;
+            width: 100%;
+            max-width: 1200px;
+            padding: 10px;
+            background-color: rgba(0, 0, 0, 0.5);
+            position: absolute;
+            bottom: 0;
+            z-index: 10;
+            box-sizing: border-box;
+        }
+
+        .panels-container {
+            display: flex;
+            width: 100%;
+            max-width: 1200px;
+            margin: 10px auto;
+            gap: 10px;
+        }
+
+        .panel {
+            flex: 1;
+            background-color: #1E1E1E;
+            border-radius: 5px;
+            overflow: hidden;
+            box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
+            transition: height 0.3s ease;
+        }
+
+        .panel-header {
+            background-color: #272727;
+            padding: 10px 15px;
+            display: flex;
+            justify-content: space-between;
+            align-items: center;
+            cursor: pointer;
+        }
+
+        .panel-title {
+            font-weight: bold;
+            color: #4CAF50;
+        }
+
+        .toggle-button {
+            background: none;
+            border: none;
+            color: #e0e0e0;
+            font-size: 18px;
+            cursor: pointer;
+        }
+
+        .toggle-button:focus {
+            outline: none;
+        }
+
+        .panel-content {
+            padding: 15px;
+            max-height: 300px;
+            overflow-y: auto;
+            transition: all 0.3s ease;
+        }
+
+        .collapsed .panel-content {
+            max-height: 0;
+            padding-top: 0;
+            padding-bottom: 0;
+            overflow: hidden;
+        }
+
+        button {
+            background-color: #4CAF50;
+            color: white;
+            border: none;
+            padding: 10px 15px;
+            text-align: center;
+            text-decoration: none;
+            display: inline-block;
+            font-size: 14px;
+            border-radius: 5px;
+            cursor: pointer;
+            margin: 5px;
+            transition: background-color 0.3s;
+        }
+
+        button:hover {
+            background-color: #45a049;
+        }
+
+        button:disabled {
+            background-color: #cccccc;
+            cursor: not-allowed;
+        }
+
+        select {
+            padding: 10px;
+            border-radius: 5px;
+            background-color: #2A2A2A;
+            color: #e0e0e0;
+            border: 1px solid #4CAF50;
+        }
+
+        .status {
+            margin-top: 10px;
+            color: #4CAF50;
+        }
+
+        .key-indicators {
+            display: flex;
+            justify-content: center;
+            margin-top: 15px;
+        }
+
+        .key {
+            width: 40px;
+            height: 40px;
+            margin: 0 5px;
+            background-color: #2A2A2A;
+            border: 1px solid #444;
+            border-radius: 5px;
+            display: flex;
+            justify-content: center;
+            align-items: center;
+            font-weight: bold;
+            transition: background-color 0.2s;
+        }
+
+        .key.active {
+            background-color: #4CAF50;
+            color: white;
+        }
+
+        .key-row {
+            display: flex;
+            justify-content: center;
+            margin: 5px 0;
+        }
+
+        .spacebar {
+            width: 150px;
+        }
+
+        .connection-info {
+            font-family: monospace;
+            height: 100%;
+            overflow-y: auto;
+        }
+
+        .log-entry {
+            margin: 5px 0;
+            padding: 3px;
+            border-bottom: 1px solid #333;
+        }
+
+        .fps-counter {
+            position: absolute;
+            top: 10px;
+            right: 10px;
+            background-color: rgba(0,0,0,0.5);
+            color: #4CAF50;
+            padding: 5px;
+            border-radius: 3px;
+            font-family: monospace;
+            z-index: 20;
+        }
+
+        #mouse-position {
+            position: absolute;
+            top: 10px;
+            left: 10px;
+            background-color: rgba(0,0,0,0.5);
+            color: #4CAF50;
+            padding: 5px;
+            border-radius: 3px;
+            font-family: monospace;
+            z-index: 20;
+        }
+
+        @media (max-width: 768px) {
+            .panels-container {
+                flex-direction: column;
+            }
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <div class="game-area">
+            <div id="mouse-tracking-area">
+                <img id="game-canvas" src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" alt="Game Frame">
+                <div id="mouse-position">Mouse: 0.00, 0.00</div>
+                <div class="fps-counter" id="fps-counter">FPS: 0</div>
+            </div>
+
+            <div class="controls">
+                <button id="connect-btn">Connect</button>
+                <button id="start-stream-btn" disabled>Start Stream</button>
+                <button id="stop-stream-btn" disabled>Stop Stream</button>
+                <select id="scene-select" disabled>
+                    <option value="forest">Forest</option>
+                    <option value="desert">Desert</option>
+                    <option value="beach">Beach</option>
+                    <option value="hills">Hills</option>
+                    <option value="river">River</option>
+                    <option value="icy">Icy</option>
+                    <option value="mushroom">Mushroom</option>
+                    <option value="plain">Plain</option>
+                </select>
+            </div>
+        </div>
+
+        <div class="panels-container">
+            <!-- Controls Panel -->
+            <div class="panel" id="controls-panel">
+                <div class="panel-header" onclick="togglePanel('controls-panel')">
+                    <div class="panel-title">Keyboard Controls</div>
+                    <button class="toggle-button">−</button>
+                </div>
+                <div class="panel-content">
+                    <div class="key-indicators">
+                        <div class="key-row">
+                            <div id="key-w" class="key">W</div>
+                        </div>
+                        <div class="key-row">
+                            <div id="key-a" class="key">A</div>
+                            <div id="key-s" class="key">S</div>
+                            <div id="key-d" class="key">D</div>
+                        </div>
+                        <div class="key-row">
+                            <div id="key-space" class="key spacebar">SPACE</div>
+                        </div>
+                        <div class="key-row">
+                            <div id="key-shift" class="key">SHIFT</div>
+                        </div>
+                    </div>
+                    <p class="status">
+                        W or ↑ = Forward, S or ↓ = Back, A or ← = Left, D or → = Right<br>
+                        Space = Jump, Shift = Attack<br>
+                        Click on game view to capture mouse (ESC to release)<br>
+                        Mouse = Look around
+                    </p>
+                </div>
+            </div>
+
+            <!-- Connection Log Panel -->
+            <div class="panel" id="log-panel">
+                <div class="panel-header" onclick="togglePanel('log-panel')">
+                    <div class="panel-title">Connection Log</div>
+                    <button class="toggle-button">−</button>
+                </div>
+                <div class="panel-content">
+                    <div class="connection-info" id="connection-log">
+                        <div class="log-entry">Waiting to connect...</div>
+                    </div>
+                </div>
+            </div>
+        </div>
+    </div>
+
+    <script src="./assets/client.js"></script>
+</body>
+</html>
example/requirements.txt ADDED
@@ -0,0 +1,23 @@
+ diffusers==0.32.2
+ einops==0.8.1
+
+ #flash_attn==2.7.4.post1
+ flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+
+ ftfy==6.3.1
+ imageio==2.34.0
+ numpy==1.24.4
+ opencv_python==4.9.0.80
+ opencv_python_headless==4.9.0.80
+ packaging==25.0
+ peft==0.14.0
+ Pillow==11.2.1
+ regex==2024.11.6
+ safetensors==0.5.3
+ torch==2.5.1
+ torchvision==0.20.1
+ torchaudio==2.5.1
+ transformers==4.47.1
+ aiohttp==3.9.3
+ jinja2==3.1.3
+ python-multipart==0.0.6
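
Note that the pinned flash-attn wheel above is built for CUDA 12, the torch 2.4 ABI, and CPython 3.10, while torch==2.5.1 is pinned alongside it, so binary compatibility is worth verifying in the target environment. A minimal sanity-check sketch, assuming the requirements installed cleanly:

# Hedged sanity check: confirms the pinned stack imports and reports versions.
# If the flash_attn import fails with an ABI error, the prebuilt wheel and the
# pinned torch version do not match.
import torch
import flash_attn  # noqa: F401  (importing exercises the compiled extension)

print("torch:", torch.__version__, "CUDA available:", torch.cuda.is_available())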
example/server.py ADDED
@@ -0,0 +1,649 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ """
+ MatrixGame WebSocket Gaming Server
+
+ This script implements a WebSocket server for the MatrixGame project,
+ allowing real-time streaming of game frames based on player inputs.
+ """
+
+ import asyncio
+ import json
+ import logging
+ import os
+ import pathlib
+ import time
+ import uuid
+ import base64
+ import argparse
+ from typing import Dict, List, Any, Optional
+ from aiohttp import web, WSMsgType
+
+ # Import the game engine
+ from engine import MatrixGameEngine
+ from utils import logger, parse_model_args, setup_gpu_environment
+
+ class GameSession:
+     """
+     Represents a user's gaming session.
+     Each WebSocket connection gets its own session with separate queues.
+     """
+     def __init__(self, user_id: str, ws: web.WebSocketResponse, game_manager):
+         self.user_id = user_id
+         self.ws = ws
+         self.game_manager = game_manager
+
+         # Create action queue for this user session
+         self.action_queue = asyncio.Queue()
+
+         # Session creation time
+         self.created_at = time.time()
+         self.last_activity = time.time()
+
+         # Game state
+         self.current_scene = "forest"  # Default scene
+         self.is_streaming = False
+         self.stream_task = None
+
+         # Current input state
+         self.keyboard_state = [0, 0, 0, 0, 0, 0]  # forward, back, left, right, jump, attack
+         self.mouse_state = [0, 0]  # x, y
+
+         self.background_tasks = []
+
+     async def start(self):
+         """Start all the queue processors for this session"""
+         self.background_tasks = [
+             asyncio.create_task(self._process_action_queue()),
+         ]
+         logger.info(f"Started game session for user {self.user_id}")
+
+     async def stop(self):
+         """Stop all background tasks for this session"""
+         # Stop streaming if active
+         if self.is_streaming and self.stream_task:
+             self.is_streaming = False
+             self.stream_task.cancel()
+             try:
+                 await self.stream_task
+             except asyncio.CancelledError:
+                 pass
+
+         # Cancel other background tasks
+         for task in self.background_tasks:
+             task.cancel()
+
+         try:
+             # Wait for tasks to complete cancellation
+             await asyncio.gather(*self.background_tasks, return_exceptions=True)
+         except asyncio.CancelledError:
+             pass
+
+         logger.info(f"Stopped game session for user {self.user_id}")
+
+     async def _process_action_queue(self):
+         """Process game actions from the queue"""
+         while True:
+             data = await self.action_queue.get()
+             try:
+                 action_type = data.get('action')
+
+                 if action_type == 'start_stream':
+                     result = await self._handle_start_stream(data)
+                 elif action_type == 'stop_stream':
+                     result = await self._handle_stop_stream(data)
+                 elif action_type == 'keyboard_input':
+                     result = await self._handle_keyboard_input(data)
+                 elif action_type == 'mouse_input':
+                     result = await self._handle_mouse_input(data)
+                 elif action_type == 'change_scene':
+                     result = await self._handle_scene_change(data)
+                 else:
+                     result = {
+                         'action': action_type,
+                         'requestId': data.get('requestId'),
+                         'success': False,
+                         'error': f'Unknown action: {action_type}'
+                     }
+
+                 # Send response back to the client
+                 await self.ws.send_json(result)
+
+                 # Update last activity time
+                 self.last_activity = time.time()
+
+             except Exception as e:
+                 logger.error(f"Error processing action for user {self.user_id}: {str(e)}")
+                 try:
+                     await self.ws.send_json({
+                         'action': data.get('action'),
+                         'requestId': data.get('requestId', 'unknown'),
+                         'success': False,
+                         'error': f'Error processing action: {str(e)}'
+                     })
+                 except Exception as send_error:
+                     logger.error(f"Error sending error response: {send_error}")
+             finally:
+                 self.action_queue.task_done()
+
+     async def _handle_start_stream(self, data: Dict) -> Dict:
+         """Handle request to start streaming frames"""
+         if self.is_streaming:
+             return {
+                 'action': 'start_stream',
+                 'requestId': data.get('requestId'),
+                 'success': False,
+                 'error': 'Stream already active'
+             }
+
+         fps = data.get('fps', 16)
+         self.is_streaming = True
+         self.stream_task = asyncio.create_task(self._stream_frames(fps))
+
+         return {
+             'action': 'start_stream',
+             'requestId': data.get('requestId'),
+             'success': True,
+             'message': f'Streaming started at {fps} FPS'
+         }
+
+     async def _handle_stop_stream(self, data: Dict) -> Dict:
+         """Handle request to stop streaming frames"""
+         if not self.is_streaming:
+             return {
+                 'action': 'stop_stream',
+                 'requestId': data.get('requestId'),
+                 'success': False,
+                 'error': 'No active stream to stop'
+             }
+
+         self.is_streaming = False
+         if self.stream_task:
+             self.stream_task.cancel()
+             try:
+                 await self.stream_task
+             except asyncio.CancelledError:
+                 pass
+             self.stream_task = None
+
+         return {
+             'action': 'stop_stream',
+             'requestId': data.get('requestId'),
+             'success': True,
+             'message': 'Streaming stopped'
+         }
+
+     async def _handle_keyboard_input(self, data: Dict) -> Dict:
+         """Handle keyboard input from client"""
+         key = data.get('key', '')
+         pressed = data.get('pressed', False)
+
+         # Map key to keyboard state index
+         key_map = {
+             'w': 0, 'forward': 0,
+             's': 1, 'back': 1, 'backward': 1,
+             'a': 2, 'left': 2,
+             'd': 3, 'right': 3,
+             'space': 4, 'jump': 4,
+             'shift': 5, 'attack': 5, 'ctrl': 5
+         }
+
+         if key.lower() in key_map:
+             key_idx = key_map[key.lower()]
+             self.keyboard_state[key_idx] = 1 if pressed else 0
+
+         return {
+             'action': 'keyboard_input',
+             'requestId': data.get('requestId'),
+             'success': True,
+             'keyboardState': self.keyboard_state
+         }
+
+     async def _handle_mouse_input(self, data: Dict) -> Dict:
+         """Handle mouse movement/input from client"""
+         mouse_x = data.get('x', 0)
+         mouse_y = data.get('y', 0)
+
+         # Update mouse state; values are expected to arrive normalized to [-1, 1] by the client
+         self.mouse_state = [float(mouse_x), float(mouse_y)]
+
+         return {
+             'action': 'mouse_input',
+             'requestId': data.get('requestId'),
+             'success': True,
+             'mouseState': self.mouse_state
+         }
+
+     async def _handle_scene_change(self, data: Dict) -> Dict:
+         """Handle scene change requests"""
+         scene_name = data.get('scene', 'forest')
+         valid_scenes = self.game_manager.valid_scenes
+
+         if scene_name not in valid_scenes:
+             return {
+                 'action': 'change_scene',
+                 'requestId': data.get('requestId'),
+                 'success': False,
+                 'error': f'Invalid scene: {scene_name}. Valid scenes are: {", ".join(valid_scenes)}'
+             }
+
+         self.current_scene = scene_name
+
+         return {
+             'action': 'change_scene',
+             'requestId': data.get('requestId'),
+             'success': True,
+             'scene': scene_name
+         }
+
+     async def _stream_frames(self, fps: int):
+         """Stream frames to the client at the specified FPS"""
+         frame_interval = 1.0 / fps  # Time between frames in seconds
+
+         try:
+             while self.is_streaming:
+                 start_time = time.time()
+
+                 # Generate frame based on current keyboard and mouse state
+                 keyboard_condition = [self.keyboard_state]
+                 mouse_condition = [self.mouse_state]
+
+                 # Use the engine to generate the next frame
+                 frame_bytes = self.game_manager.engine.generate_frame(
+                     self.current_scene, keyboard_condition, mouse_condition
+                 )
+
+                 # Encode as base64 for sending in JSON
+                 frame_base64 = base64.b64encode(frame_bytes).decode('utf-8')
+
+                 # Send frame to client
+                 await self.ws.send_json({
+                     'action': 'frame',
+                     'frameData': frame_base64,
+                     'timestamp': time.time()
+                 })
+
+                 # Calculate sleep time to maintain FPS
+                 elapsed = time.time() - start_time
+                 sleep_time = max(0, frame_interval - elapsed)
+                 await asyncio.sleep(sleep_time)
+
+         except asyncio.CancelledError:
+             logger.info(f"Frame streaming cancelled for user {self.user_id}")
+         except Exception as e:
+             logger.error(f"Error in frame streaming for user {self.user_id}: {str(e)}")
+             if self.ws.closed:
+                 logger.info(f"WebSocket closed for user {self.user_id}")
+                 return
+
+             # Notify client of error
+             try:
+                 await self.ws.send_json({
+                     'action': 'frame_error',
+                     'error': f'Streaming error: {str(e)}'
+                 })
+             except Exception:
+                 pass
+
+             # Stop streaming
+             self.is_streaming = False
+
+ class GameManager:
+     """
+     Manages all active gaming sessions and shared resources.
+     """
+     def __init__(self, args: argparse.Namespace):
+         self.sessions = {}
+         self.session_lock = asyncio.Lock()
+
+         # Initialize game engine
+         self.engine = MatrixGameEngine(args)
+
+         # Load valid scenes from engine
+         self.valid_scenes = self.engine.get_valid_scenes()
+
+     async def create_session(self, user_id: str, ws: web.WebSocketResponse) -> GameSession:
+         """Create a new game session"""
+         async with self.session_lock:
+             # Create a new session for this user
+             session = GameSession(user_id, ws, self)
+             await session.start()
+             self.sessions[user_id] = session
+             return session
+
+     async def delete_session(self, user_id: str) -> None:
+         """Delete a game session and clean up resources"""
+         async with self.session_lock:
+             if user_id in self.sessions:
+                 session = self.sessions[user_id]
+                 await session.stop()
+                 del self.sessions[user_id]
+                 logger.info(f"Deleted game session for user {user_id}")
+
+     def get_session(self, user_id: str) -> Optional[GameSession]:
+         """Get a game session if it exists"""
+         return self.sessions.get(user_id)
+
+     async def close_all_sessions(self) -> None:
+         """Close all active sessions (used during shutdown)"""
+         async with self.session_lock:
+             for user_id, session in list(self.sessions.items()):
+                 await session.stop()
+             self.sessions.clear()
+             logger.info("Closed all active game sessions")
+
+     @property
+     def session_count(self) -> int:
+         """Get the number of active sessions"""
+         return len(self.sessions)
+
+     def get_session_stats(self) -> Dict:
+         """Get statistics about active sessions"""
+         stats = {
+             'total_sessions': len(self.sessions),
+             'active_scenes': {},
+             'streaming_sessions': 0
+         }
+
+         # Count sessions by scene and streaming status
+         for session in self.sessions.values():
+             scene = session.current_scene
+             stats['active_scenes'][scene] = stats['active_scenes'].get(scene, 0) + 1
+             if session.is_streaming:
+                 stats['streaming_sessions'] += 1
+
+         return stats
+
+ # Create global game manager
+ game_manager = None
+
+ async def status_handler(request: web.Request) -> web.Response:
+     """Handler for API status endpoint"""
+     # Get session statistics
+     session_stats = game_manager.get_session_stats()
+
+     return web.json_response({
+         'product': 'MatrixGame WebSocket Server',
+         'version': '1.0.0',
+         'active_sessions': session_stats,
+         'available_scenes': game_manager.valid_scenes
+     })
+
+ async def root_handler(request: web.Request) -> web.Response:
+     """Handler for serving the client at the root path"""
+     client_path = pathlib.Path(__file__).parent / 'client' / 'index.html'
+
+     with open(client_path, 'r') as file:
+         html_content = file.read()
+
+     return web.Response(text=html_content, content_type='text/html')
+
+ async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+     """Handle WebSocket connections with robust error handling"""
+     logger.info(f"WebSocket connection attempt - PATH: {request.path}, QUERY: {request.query_string}")
+
+     # Log request headers at debug level only (could contain sensitive information)
+     logger.debug(f"WebSocket request headers: {dict(request.headers)}")
+
+     # Prepare a WebSocket response with appropriate settings
+     ws = web.WebSocketResponse(
+         max_msg_size=1024*1024*10,  # 10MB max message size
+         timeout=60.0,
+         heartbeat=30.0  # Add heartbeat to keep connection alive
+     )
+
+     # Check if WebSocket protocol is supported
+     if not ws.can_prepare(request):
+         logger.error("Cannot prepare WebSocket: WebSocket protocol not supported")
+         return web.Response(status=400, text="WebSocket protocol not supported")
+
+     try:
+         logger.info("Preparing WebSocket connection...")
+         await ws.prepare(request)
+
+         # Generate a unique user ID for this connection
+         user_id = str(uuid.uuid4())
+
+         # Get client IP address
+         peername = request.transport.get_extra_info('peername')
+         if peername is not None:
+             client_ip = peername[0]
+         else:
+             client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()
+
+         # Log connection success
+         logger.info(f"Client {user_id} connecting from IP: {client_ip} - WebSocket connection established")
+
+         # Mark that the session is established
+         is_session_created = False
+
+         try:
+             # Store the user ID in the websocket for easy access
+             ws.user_id = user_id
+
+             # Create a new session for this user
+             logger.info(f"Creating game session for user {user_id}")
+             user_session = await game_manager.create_session(user_id, ws)
+             is_session_created = True
+             logger.info(f"Game session created for user {user_id}")
+         except Exception as session_error:
+             logger.error(f"Error creating game session: {str(session_error)}", exc_info=True)
+             if not ws.closed:
+                 await ws.close(code=1011, message=f"Server error: {str(session_error)}".encode())
+             if is_session_created:
+                 await game_manager.delete_session(user_id)
+             return ws
+     except Exception as e:
+         logger.error(f"Error establishing WebSocket connection: {str(e)}", exc_info=True)
+         if not ws.closed and ws.prepared:
+             await ws.close(code=1011, message=f"Server error: {str(e)}".encode())
+         return ws
+
+     # Send initial welcome message
+     try:
+         await ws.send_json({
+             'action': 'welcome',
+             'userId': user_id,
+             'message': 'Welcome to the MatrixGame WebSocket server!',
+             'scenes': game_manager.valid_scenes
+         })
+         logger.info(f"Sent welcome message to user {user_id}")
+     except Exception as welcome_error:
+         logger.error(f"Error sending welcome message: {str(welcome_error)}")
+         if not ws.closed:
+             await ws.close(code=1011, message=b"Failed to send welcome message")
+         await game_manager.delete_session(user_id)
+         return ws
+
+     try:
+         async for msg in ws:
+             if msg.type == WSMsgType.TEXT:
+                 try:
+                     data = json.loads(msg.data)
+                     action = data.get('action')
+
+                     logger.debug(f"Received {action} message from user {user_id}")
+
+                     if action == 'ping':
+                         # Respond to ping immediately
+                         await ws.send_json({
+                             'action': 'pong',
+                             'requestId': data.get('requestId'),
+                             'timestamp': time.time()
+                         })
+                     else:
+                         # Route game actions to the session's action queue
+                         await user_session.action_queue.put(data)
+
+                 except json.JSONDecodeError:
+                     logger.error(f"Invalid JSON from user {user_id}: {msg.data}")
+                     if not ws.closed:
+                         await ws.send_json({
+                             'error': 'Invalid JSON message',
+                             'success': False
+                         })
+                 except Exception as e:
+                     logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
+                     if not ws.closed:
+                         await ws.send_json({
+                             'action': data.get('action') if 'data' in locals() else 'unknown',
+                             'success': False,
+                             'error': f'Error processing message: {str(e)}'
+                         })
+
+             elif msg.type == WSMsgType.ERROR:
+                 logger.error(f"WebSocket error for user {user_id}: {ws.exception()}")
+                 break
+
+             elif msg.type == WSMsgType.CLOSE:
+                 logger.info(f"WebSocket close received for user {user_id} (code: {msg.data}, message: {msg.extra})")
+                 break
+
+             elif msg.type == WSMsgType.CLOSING:
+                 logger.info(f"WebSocket closing for user {user_id}")
+                 break
+
+             elif msg.type == WSMsgType.CLOSED:
+                 logger.info(f"WebSocket already closed for user {user_id}")
+                 break
+
+     except Exception as ws_error:
+         logger.error(f"Unexpected WebSocket error for user {user_id}: {str(ws_error)}", exc_info=True)
+     finally:
+         # Cleanup session
+         try:
+             logger.info(f"Cleaning up session for user {user_id}")
+             await game_manager.delete_session(user_id)
+             logger.info(f"Connection closed for user {user_id}")
+         except Exception as cleanup_error:
+             logger.error(f"Error during session cleanup for user {user_id}: {str(cleanup_error)}")
+
+     return ws
+
+ async def init_app(args, base_path="") -> web.Application:
+     """Initialize the web application"""
+     global game_manager
+
+     # Initialize game manager with command line args
+     game_manager = GameManager(args)
+
+     app = web.Application(
+         client_max_size=1024**2*10  # 10MB max size
+     )
+
+     # Add cleanup logic
+     async def cleanup(app):
+         logger.info("Shutting down server, closing all sessions...")
+         await game_manager.close_all_sessions()
+
+     app.on_shutdown.append(cleanup)
+
+     # Add routes with CORS headers for WebSockets
+     # Configure CORS for all routes
+     @web.middleware
+     async def cors_middleware(request, handler):
+         if request.method == 'OPTIONS':
+             # Handle preflight requests
+             resp = web.Response()
+             resp.headers['Access-Control-Allow-Origin'] = '*'
+             resp.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
+             resp.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
+             return resp
+
+         # Normal request, call the handler
+         resp = await handler(request)
+
+         # Add CORS headers to the response
+         resp.headers['Access-Control-Allow-Origin'] = '*'
+         resp.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
+         resp.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
+         return resp
+
+     app.middlewares.append(cors_middleware)
+
+     # Add a debug endpoint to help diagnose WebSocket issues
+     async def debug_handler(request):
+         client_ip = request.remote
+         headers = dict(request.headers)
+         server_host = request.host
+
+         debug_info = {
+             "client_ip": client_ip,
+             "server_host": server_host,
+             "headers": headers,
+             "request_path": request.path,
+             "server_time": time.time(),
+             "base_path": base_path,
+             "websocket_route": f"{base_path}/ws",
+             "all_routes": [route.name for route in app.router.routes() if route.name],
+             "server_info": {
+                 "active_sessions": game_manager.session_count,
+                 "available_scenes": game_manager.valid_scenes
+             }
+         }
+
+         return web.json_response(debug_info)
+
+     # Set up routes with the base_path
+     # Add multiple WebSocket routes to ensure compatibility
+     logger.info(f"Setting up WebSocket route at {base_path}/ws")
+     app.router.add_get(f'{base_path}/ws', websocket_handler, name='ws_handler')
+
+     # Also add WebSocket route at the root for Hugging Face compatibility
+     if base_path:
+         logger.info("Adding additional WebSocket route at /ws")
+         app.router.add_get('/ws', websocket_handler, name='ws_root_handler')
+
+     # Add routes for API and debug endpoints
+     app.router.add_get(f'{base_path}/api/status', status_handler, name='status_handler')
+     app.router.add_get(f'{base_path}/api/debug', debug_handler, name='debug_handler')
+
+     # Serve the client at both the base path and root path for compatibility
+     app.router.add_get(f'{base_path}/', root_handler, name='root_handler')
+
+     # Always serve at the root path for Hugging Face Spaces compatibility
+     if base_path:
+         app.router.add_get('/', root_handler, name='root_handler_no_base')
+
+     # Set up static file serving for the client assets
+     app.router.add_static(f'{base_path}/assets', pathlib.Path(__file__).parent / 'client', name='static_handler')
+
+     # Add static file serving at root for compatibility
+     if base_path:
+         app.router.add_static('/assets', pathlib.Path(__file__).parent / 'client', name='static_handler_no_base')
+
+     return app
+
+ def parse_args() -> argparse.Namespace:
+     """Parse server-specific command line arguments"""
+     parser = argparse.ArgumentParser(description="MatrixGame WebSocket Server")
+     parser.add_argument("--host", type=str, default="0.0.0.0", help="Host IP to bind to")
+     parser.add_argument("--port", type=int, default=8080, help="Port to listen on")
+     parser.add_argument("--path", type=str, default="", help="Base path for the server (for proxy setups)")
+
+     # Parse server args first
+     server_args, remaining_args = parser.parse_known_args()
+
+     # Parse model args and combine
+     model_args = parse_model_args()
+
+     # Combine all args
+     combined_args = argparse.Namespace(**vars(server_args), **vars(model_args))
+
+     return combined_args
+
+ if __name__ == '__main__':
+     # Configure GPU environment
+     setup_gpu_environment()
+
+     # Parse command line arguments
+     args = parse_args()
+
+     # Initialize app
+     loop = asyncio.get_event_loop()
+     app = loop.run_until_complete(init_app(args, base_path=args.path))
+
+     # Start server
+     logger.info(f"Starting MatrixGame WebSocket Server at {args.host}:{args.port}")
+     web.run_app(app, host=args.host, port=args.port)
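
For reference, a minimal Python client sketch for the protocol implemented above. The endpoint URL and request IDs are illustrative; the action names and payload shapes follow the handlers in server.py, and only standard aiohttp client calls are used:

# Hedged sketch: connect, start the stream, collect a few frames, stop.
# Assumes the server runs locally on port 8080 with an empty base path.
import asyncio
import base64
import aiohttp

async def main():
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect("ws://localhost:8080/ws") as ws:
            print(await ws.receive_json())  # 'welcome' message with userId and scenes
            await ws.send_json({"action": "start_stream", "requestId": "r1", "fps": 16})
            frames = 0
            while frames < 16:  # roughly one second of frames at 16 FPS
                msg = await ws.receive_json()
                if msg.get("action") == "frame":
                    jpeg_bytes = base64.b64decode(msg["frameData"])  # one JPEG frame
                    frames += 1
            await ws.send_json({"action": "stop_stream", "requestId": "r2"})

asyncio.run(main())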
example/utils.py ADDED
@@ -0,0 +1,202 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+
+ """
+ MatrixGame Utility Functions
+
+ This module contains helper functions and utilities for the MatrixGame project.
+ """
+
+ import os
+ import logging
+ import argparse
+ import torch
+ import numpy as np
+ import cv2
+ from PIL import Image
+ from typing import Dict, List, Tuple, Any, Optional, Union
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ def setup_gpu_environment():
+     """
+     Configure the GPU environment and log GPU information.
+
+     Returns:
+         bool: True if CUDA is available, False otherwise
+     """
+     # Set CUDA memory allocation environment variable for better performance
+     os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
+
+     # Check if CUDA is available and log information
+     if torch.cuda.is_available():
+         gpu_count = torch.cuda.device_count()
+         gpu_info = []
+
+         for i in range(gpu_count):
+             gpu_name = torch.cuda.get_device_name(i)
+             gpu_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3)  # Convert to GB
+             gpu_info.append(f"GPU {i}: {gpu_name} ({gpu_memory:.2f} GB)")
+
+         logger.info(f"CUDA is available. Found {gpu_count} GPU(s):")
+         for info in gpu_info:
+             logger.info(f"  {info}")
+         return True
+     else:
+         logger.warning("CUDA is not available. Running in CPU-only mode.")
+         return False
+
+ def parse_model_args() -> argparse.Namespace:
+     """
+     Parse command line arguments for model paths and configuration.
+
+     Returns:
+         argparse.Namespace: Parsed arguments
+     """
+     parser = argparse.ArgumentParser(description="MatrixGame Model Configuration")
+
+     # Model paths
+     parser.add_argument("--model_root", type=str, default="./models/matrixgame",
+                         help="Root directory for model files")
+     parser.add_argument("--dit_path", type=str, default=None,
+                         help="Path to DIT model. If not provided, will use MODEL_ROOT/dit/")
+     parser.add_argument("--vae_path", type=str, default=None,
+                         help="Path to VAE model. If not provided, will use MODEL_ROOT/vae/")
+     parser.add_argument("--textenc_path", type=str, default=None,
+                         help="Path to text encoder model. If not provided, will use MODEL_ROOT")
+
+     # Model settings
+     parser.add_argument("--inference_steps", type=int, default=20,
+                         help="Number of inference steps for frame generation (lower is faster)")
+     parser.add_argument("--guidance_scale", type=float, default=6.0,
+                         help="Guidance scale for generation")
+     parser.add_argument("--frame_width", type=int, default=640,
+                         help="Width of the generated frames")
+     parser.add_argument("--frame_height", type=int, default=360,
+                         help="Height of the generated frames")
+     parser.add_argument("--num_pre_frames", type=int, default=3,
+                         help="Number of pre-frames for conditioning")
+     parser.add_argument("--fps", type=int, default=16,
+                         help="Frames per second for video")
+
+     args = parser.parse_args()
+
+     # Set environment variables for model paths if provided
+     if args.model_root:
+         os.environ.setdefault("MODEL_ROOT", args.model_root)
+     if args.dit_path:
+         os.environ.setdefault("DIT_PATH", args.dit_path)
+     else:
+         os.environ.setdefault("DIT_PATH", os.path.join(os.environ.get("MODEL_ROOT", "./models/matrixgame"), "dit/"))
+     if args.vae_path:
+         os.environ.setdefault("VAE_PATH", args.vae_path)
+     else:
+         os.environ.setdefault("VAE_PATH", os.path.join(os.environ.get("MODEL_ROOT", "./models/matrixgame"), "vae/"))
+     if args.textenc_path:
+         os.environ.setdefault("TEXTENC_PATH", args.textenc_path)
+     else:
+         os.environ.setdefault("TEXTENC_PATH", os.environ.get("MODEL_ROOT", "./models/matrixgame"))
+
+     return args
+
+ def visualize_controls(frame: np.ndarray, keyboard_condition: List, mouse_condition: List,
+                        frame_width: int, frame_height: int) -> np.ndarray:
+     """
+     Visualize keyboard and mouse controls on the frame.
+
+     Args:
+         frame: The video frame to visualize on
+         keyboard_condition: Keyboard state as a list
+         mouse_condition: Mouse state as a list
+         frame_width: Width of the frame
+         frame_height: Height of the frame
+
+     Returns:
+         np.ndarray: Frame with visualized controls
+     """
+     # Clone the frame to avoid modifying the original
+     frame = frame.copy()
+
+     # If we have keyboard/mouse conditions, visualize them on the frame
+     if keyboard_condition:
+         # Visualize keyboard inputs
+         keys = ["W", "S", "A", "D", "JUMP", "ATTACK"]
+         for i, key_pressed in enumerate(keyboard_condition[0]):
+             color = (0, 255, 0) if key_pressed else (100, 100, 100)
+             cv2.putText(frame, keys[i], (20 + i*100, 30),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
+
+     if mouse_condition:
+         # Visualize mouse movement
+         mouse_x, mouse_y = mouse_condition[0]
+         # Scale mouse values for visualization
+         offset_x = int(mouse_x * 100)
+         offset_y = int(mouse_y * 100)
+         center_x, center_y = frame_width // 2, frame_height // 2
+         cv2.circle(frame, (center_x + offset_x, center_y - offset_y), 10, (255, 0, 0), -1)
+         cv2.putText(frame, f"Mouse: {mouse_x:.2f}, {mouse_y:.2f}",
+                     (frame_width - 250, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
+
+     return frame
+
+ def frame_to_jpeg(frame: np.ndarray, frame_height: int, frame_width: int) -> bytes:
+     """
+     Convert a frame to JPEG bytes.
+
+     Args:
+         frame: The video frame to convert
+         frame_height: Height of the frame for fallback
+         frame_width: Width of the frame for fallback
+
+     Returns:
+         bytes: JPEG bytes of the frame
+     """
+     success, buffer = cv2.imencode('.jpg', frame)
+     if not success:
+         logger.error("Failed to encode frame as JPEG")
+         # Return a blank frame
+         blank = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 100
+         success, buffer = cv2.imencode('.jpg', blank)
+
+     return buffer.tobytes()
+
+ def load_scene_frames(scene_name: str, frame_width: int, frame_height: int) -> List[np.ndarray]:
+     """
+     Load initial frames for a scene from asset directory.
+
+     Args:
+         scene_name: Name of the scene
+         frame_width: Width to resize frames to
+         frame_height: Height to resize frames to
+
+     Returns:
+         List[np.ndarray]: List of frames as numpy arrays
+     """
+     frames = []
+     scene_dir = f"./GameWorldScore/asset/init_image/{scene_name}"
+
+     if os.path.exists(scene_dir):
+         image_files = sorted([f for f in os.listdir(scene_dir) if f.endswith('.png') or f.endswith('.jpg')])
+         for img_file in image_files:
+             try:
+                 img_path = os.path.join(scene_dir, img_file)
+                 img = Image.open(img_path).convert("RGB")
+                 img = img.resize((frame_width, frame_height))
+                 frames.append(np.array(img))
+             except Exception as e:
+                 logger.error(f"Error loading image {img_file}: {str(e)}")
+
+     # If no frames were loaded, create a default colored frame with text
+     if not frames:
+         frame = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 100
+         # Add scene name as text
+         cv2.putText(frame, f"Scene: {scene_name}", (50, 180),
+                     cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
+         frames.append(frame)
+
+     return frames
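
A short usage sketch tying these helpers together. The scene name and sizes are illustrative, and the asset directory must exist for real frames to load (otherwise the placeholder frame is used):

# Hedged sketch: build one debug frame with the control overlay and JPEG-encode it.
from utils import load_scene_frames, visualize_controls, frame_to_jpeg

frames = load_scene_frames("forest", frame_width=640, frame_height=360)
keyboard = [[1, 0, 0, 0, 0, 0]]   # forward pressed (W)
mouse = [[0.25, -0.10]]           # normalized look offsets, as the server sends them
overlay = visualize_controls(frames[0], keyboard, mouse, 640, 360)
jpeg_bytes = frame_to_jpeg(overlay, 360, 640)  # ready for base64 framing over the WebSocket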
game/spawn/1/act.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7caabf75f45d4c8bae5c0b66dc2b5a3cbf3ab7dbf89521d6ba539c4f30048d75
+ size 10688
game/spawn/1/full_res.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c186e607a8cc4922e17ee66c1f37dc4858adfef220b6bf48fbcda9bf75ffde34
+ size 22260128
game/spawn/1/low_res.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08e31f7ef447a1dffda3b51b4102ae4301f5804145b63c71acc5c278a294b1ee
+ size 368768
game/spawn/1/next_act.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0f2d1c96337459ddd84b9c1a8dbad9eb7284cb813a2f0ebd9cb4d757dd294e1
+ size 105728
game/spawn/2/act.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:619cdda3de7f55e48a753b64542354387294c2688a6e014af85b094996b8a486
+ size 10688
game/spawn/2/full_res.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:911389a5957acd8d7b96fccb1917ee4a4b0c74f0b9420f61580b571965dd99ff
+ size 22260128
game/spawn/2/low_res.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d8351be59118b4c2237800a9119937d472564e63c70a8d1911bc2d52ac3a95a2
+ size 368768
game/spawn/2/next_act.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba1b85fd4548f805352306a04e24368718328bfd513cb95305d0f9284fe9719f
+ size 105728
game/spawn/3/act.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60ff3ff40a48105e33feae08979dd4d7d7570984e10c950aae9308c08841400a
+ size 10688
game/spawn/3/full_res.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f1f8a9a266e05f8d8a741dadb697bf3ed89c1a166dab443c1d81091c3cf1824b
+ size 22260128
game/spawn/3/low_res.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0dd822892eb5201e23d5298da69523a4d43d740e37d7647d30a288a5b440991e
+ size 368768
game/spawn/3/next_act.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40f2a5d1381a93fc46e58f7ac0bf5df9f507c5c3b6489e85a0896deba8da9dca
+ size 105728
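
The spawn files above are Git LFS pointers to NumPy arrays. A brief inspection sketch, after fetching the objects with `git lfs pull`; array shapes and dtypes are not documented in this commit, so print them rather than assume:

# Hedged sketch: inspect the spawn arrays for one spawn point.
import numpy as np

for name in ("act", "next_act", "low_res", "full_res"):
    arr = np.load(f"game/spawn/1/{name}.npy")
    print(name, arr.shape, arr.dtype)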
index.html ADDED
@@ -0,0 +1,928 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>AI Game Multiverse</title>
+     <style>
+         body {
+             font-family: Arial, sans-serif;
+             margin: 0;
+             padding: 0;
+             background-color: #121212;
+             color: #e0e0e0;
+             display: flex;
+             flex-direction: column;
+             align-items: center;
+             user-select: none;
+             -webkit-user-select: none;
+             -moz-user-select: none;
+             -ms-user-select: none;
+             overflow-x: hidden;
+         }
+
+         .container {
+             width: 100%;
+             max-width: 100%;
+             display: flex;
+             flex-direction: column;
+             align-items: center;
+         }
+
+         .game-area {
+             display: flex;
+             flex-direction: column;
+             align-items: center;
+             width: 100%;
+             max-height: 85vh;
+             margin: 0;
+             position: relative;
+         }
+
+         #mouse-tracking-area {
+             position: relative;
+             width: 100%;
+             height: auto;
+             cursor: pointer;
+             display: flex;
+             justify-content: center;
+             align-items: center;
+             max-height: 85vh;
+         }
+
+         #game-canvas {
+             width: 100%;
+             height: auto;
+             max-height: 85vh;
+             object-fit: contain;
+             background-color: #000;
+             pointer-events: none;
+             -webkit-user-drag: none;
+             -khtml-user-drag: none;
+             -moz-user-drag: none;
+             -o-user-drag: none;
+             user-drag: none;
+         }
+
+         .controls {
+             display: flex;
+             justify-content: space-between;
+             width: 100%;
+             max-width: 1200px;
+             padding: 10px;
+             background-color: rgba(0, 0, 0, 0.5);
+             position: absolute;
+             bottom: 0;
+             z-index: 10;
+             box-sizing: border-box;
+         }
+
+         .panels-container {
+             display: flex;
+             width: 100%;
+             max-width: 1200px;
+             margin: 10px auto;
+             gap: 10px;
+         }
+
+         .panel {
+             flex: 1;
+             background-color: #1E1E1E;
+             border-radius: 5px;
+             overflow: hidden;
+             box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
+             transition: height 0.3s ease;
+         }
+
+         .panel-header {
+             background-color: #272727;
+             padding: 10px 15px;
+             display: flex;
+             justify-content: space-between;
+             align-items: center;
+             cursor: pointer;
+         }
+
+         .panel-title {
+             font-weight: bold;
+             color: #4CAF50;
+         }
+
+         .toggle-button {
+             background: none;
+             border: none;
+             color: #e0e0e0;
+             font-size: 18px;
+             cursor: pointer;
+         }
+
+         .toggle-button:focus {
+             outline: none;
+         }
+
+         .panel-content {
+             padding: 15px;
+             max-height: 300px;
+             overflow-y: auto;
+             transition: all 0.3s ease;
+         }
+
+         .collapsed .panel-content {
+             max-height: 0;
+             padding-top: 0;
+             padding-bottom: 0;
+             overflow: hidden;
+         }
+
+         button {
+             background-color: #4CAF50;
+             color: white;
+             border: none;
+             padding: 10px 15px;
+             text-align: center;
+             text-decoration: none;
+             display: inline-block;
+             font-size: 14px;
+             border-radius: 5px;
+             cursor: pointer;
+             margin: 5px;
+             transition: background-color 0.3s;
+         }
+
+         button:hover {
+             background-color: #45a049;
+         }
+
+         button:disabled {
+             background-color: #cccccc;
+             cursor: not-allowed;
+         }
+
+         select {
+             padding: 10px;
+             border-radius: 5px;
+             background-color: #2A2A2A;
+             color: #e0e0e0;
+             border: 1px solid #4CAF50;
+         }
+
+         .status {
+             margin-top: 10px;
+             color: #4CAF50;
+         }
+
+         .key-indicators {
+             display: flex;
+             justify-content: center;
+             margin-top: 15px;
+         }
+
+         .key {
+             width: 40px;
+             height: 40px;
+             margin: 0 5px;
+             background-color: #2A2A2A;
+             border: 1px solid #444;
+             border-radius: 5px;
+             display: flex;
+             justify-content: center;
+             align-items: center;
+             font-weight: bold;
+             transition: background-color 0.2s;
+         }
+
+         .key.active {
+             background-color: #4CAF50;
+             color: white;
+         }
+
+         .key-row {
+             display: flex;
+             justify-content: center;
+             margin: 5px 0;
+         }
+
+         .spacebar {
+             width: 150px;
+         }
+
+         .connection-info {
+             font-family: monospace;
+             height: 100%;
+             overflow-y: auto;
+         }
+
+         .log-entry {
+             margin: 5px 0;
+             padding: 3px;
+             border-bottom: 1px solid #333;
+         }
+
+         .fps-counter {
+             position: absolute;
+             top: 10px;
+             right: 10px;
+             background-color: rgba(0,0,0,0.5);
+             color: #4CAF50;
+             padding: 5px;
+             border-radius: 3px;
+             font-family: monospace;
+             z-index: 20;
+         }
+
+         #mouse-position {
+             position: absolute;
+             top: 10px;
+             left: 10px;
+             background-color: rgba(0,0,0,0.5);
+             color: #4CAF50;
+             padding: 5px;
+             border-radius: 3px;
+             font-family: monospace;
+             z-index: 20;
+         }
+
+         @media (max-width: 768px) {
+             .panels-container {
+                 flex-direction: column;
+             }
+         }
+
+         .header {
+             text-align: center;
+             padding: 15px;
+             margin-bottom: 20px;
+         }
+
+         .header h1 {
+             margin: 0;
+             color: #4CAF50;
+             font-size: 2rem;
+         }
+
+         .header p {
+             margin-top: 5px;
+             color: #aaa;
+         }
+     </style>
+ </head>
+ <body>
+     <div class="header">
+         <h1>AI Game Multiverse</h1>
+         <p>Play procedurally generated games using AI</p>
+     </div>
+
+     <div class="container">
+         <div class="game-area">
+             <div id="mouse-tracking-area">
+                 <img id="game-canvas" src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" alt="Game Frame">
+                 <div id="mouse-position">Mouse: 0.00, 0.00</div>
+                 <div class="fps-counter" id="fps-counter">FPS: 0</div>
+             </div>
+
+             <div class="controls">
+                 <button id="connect-btn">Connect</button>
+                 <button id="start-stream-btn" disabled>Start Stream</button>
+                 <button id="stop-stream-btn" disabled>Stop Stream</button>
+                 <select id="scene-select" disabled>
+                     <option value="forest">Forest</option>
+                     <option value="desert">Desert</option>
+                     <option value="beach">Beach</option>
+                     <option value="hills">Hills</option>
+                     <option value="river">River</option>
+                     <option value="plain">Plain</option>
+                 </select>
+             </div>
+         </div>
+
+         <div class="panels-container">
+             <!-- Controls Panel -->
+             <div class="panel" id="controls-panel">
+                 <div class="panel-header" onclick="togglePanel('controls-panel')">
+                     <div class="panel-title">Keyboard Controls</div>
+                     <button class="toggle-button">−</button>
+                 </div>
+                 <div class="panel-content">
+                     <div class="key-indicators">
+                         <div class="key-row">
+                             <div id="key-w" class="key">W</div>
+                         </div>
+                         <div class="key-row">
+                             <div id="key-a" class="key">A</div>
+                             <div id="key-s" class="key">S</div>
+                             <div id="key-d" class="key">D</div>
+                         </div>
+                         <div class="key-row">
+                             <div id="key-space" class="key spacebar">SPACE</div>
+                         </div>
+                         <div class="key-row">
+                             <div id="key-shift" class="key">SHIFT</div>
+                         </div>
+                     </div>
+                     <p class="status">
+                         W or ↑ = Forward, S or ↓ = Back, A or ← = Left, D or → = Right<br>
+                         Space = Jump, Shift = Attack<br>
+                         Click on game view to capture mouse (ESC to release)<br>
+                         Mouse = Look around
+                     </p>
+                 </div>
+             </div>
+
+             <!-- Connection Log Panel -->
+             <div class="panel" id="log-panel">
+                 <div class="panel-header" onclick="togglePanel('log-panel')">
+                     <div class="panel-title">Connection Log</div>
+                     <button class="toggle-button">−</button>
+                 </div>
+                 <div class="panel-content">
+                     <div class="connection-info" id="connection-log">
+                         <div class="log-entry">Welcome to AI Game Multiverse. Click Connect to begin.</div>
+                     </div>
+                 </div>
+             </div>
+         </div>
+     </div>
+
+     <script>
+         // WebSocket connection
+         let socket = null;
+         let userId = null;
+         let isStreaming = false;
+         let lastFrameTime = 0;
+         let frameCount = 0;
+         let fpsUpdateInterval = null;
+
+         // DOM Elements
+         const connectBtn = document.getElementById('connect-btn');
+         const startStreamBtn = document.getElementById('start-stream-btn');
+         const stopStreamBtn = document.getElementById('stop-stream-btn');
+         const sceneSelect = document.getElementById('scene-select');
+         const gameCanvas = document.getElementById('game-canvas');
+         const connectionLog = document.getElementById('connection-log');
+         const mousePosition = document.getElementById('mouse-position');
+         const fpsCounter = document.getElementById('fps-counter');
+         const mouseTrackingArea = document.getElementById('mouse-tracking-area');
+
+         // Pointer Lock API support check
+         const pointerLockSupported = 'pointerLockElement' in document ||
+             'mozPointerLockElement' in document ||
+             'webkitPointerLockElement' in document;
+
+         // Keyboard DOM elements
+         const keyElements = {
+             'w': document.getElementById('key-w'),
+             'a': document.getElementById('key-a'),
+             's': document.getElementById('key-s'),
+             'd': document.getElementById('key-d'),
+             'space': document.getElementById('key-space'),
+             'shift': document.getElementById('key-shift')
+         };
+
+         // Key mapping to action names
+         const keyToAction = {
+             'w': 'forward',
+             'arrowup': 'forward',
+             'a': 'left',
+             'arrowleft': 'left',
+             's': 'back',
+             'arrowdown': 'back',
+             'd': 'right',
+             'arrowright': 'right',
+             ' ': 'jump',
+             'shift': 'attack'
+         };
+
+         // Key state tracking
+         const keyState = {
+             'forward': false,
+             'back': false,
+             'left': false,
+             'right': false,
+             'jump': false,
+             'attack': false
+         };
+
+         // Mouse state
+         const mouseState = {
+             x: 0,
+             y: 0,
+             captured: false
+         };
+
+         // Test server connectivity before establishing WebSocket
+         async function testServerConnectivity() {
+             try {
+                 // Get base path by extracting path from the URL
+                 const basePath = window.location.pathname.replace(/\/+$/, '');
+
+                 // Try to fetch the debug endpoint to see if the server is accessible
+                 const response = await fetch(`${window.location.protocol}//${window.location.host}${basePath}/api/debug`);
+                 if (!response.ok) {
+                     throw new Error(`Server returned ${response.status}`);
+                 }
+
+                 const debugInfo = await response.json();
+                 logMessage(`Server connection test successful! Server time: ${new Date(debugInfo.server_time * 1000).toLocaleTimeString()}`);
+
+                 // Log available routes from server
+                 if (debugInfo.all_routes && debugInfo.all_routes.length > 0) {
+                     logMessage(`Available routes: ${debugInfo.all_routes.join(', ')}`);
+                 }
+
+                 // Return the debug info for connection setup
+                 return debugInfo;
+             } catch (error) {
+                 logMessage(`Server connection test failed: ${error.message}`);
+                 return null;
+             }
+         }
+
+         // Connect to WebSocket server
+         async function connectWebSocket() {
+             // First test connectivity to the server
+             logMessage('Testing server connectivity...');
+             const debugInfo = await testServerConnectivity();
+
+             // Use secure WebSocket (wss://) if the page is loaded over HTTPS
+             const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
+
+             // Get base path from URL
+             const basePath = window.location.pathname.replace(/\/+$/, '');
+
+             // Try both with and without base path for WebSocket connection
+             let serverUrl = `${protocol}//${window.location.hostname}${window.location.port ? ':' + window.location.port : ''}${basePath}/ws`;
+             logMessage(`Attempting to connect to WebSocket at ${serverUrl}...`);
+
+             // For compatibility, try the direct /ws path if the base path doesn't work
+             const fallbackUrl = `${protocol}//${window.location.hostname}${window.location.port ? ':' + window.location.port : ''}/ws`;
+
+             try {
+                 socket = new WebSocket(serverUrl);
+                 setupWebSocketHandlers();
+
+                 // Set a timeout to try the fallback URL if the first one doesn't connect
+                 setTimeout(() => {
+                     if (socket.readyState !== WebSocket.OPEN && socket.readyState !== WebSocket.CONNECTING) {
+                         logMessage(`Connection to ${serverUrl} failed. Trying fallback URL: ${fallbackUrl}`);
+                         socket = new WebSocket(fallbackUrl);
+                         setupWebSocketHandlers();
+                     }
+                 }, 3000);
+             } catch (error) {
+                 logMessage(`Error connecting to WebSocket: ${error.message}`);
+                 resetUI();
+             }
+         }
+
+         // Set up WebSocket event handlers
+         function setupWebSocketHandlers() {
+             socket.onopen = () => {
+                 logMessage('WebSocket connection established');
+                 connectBtn.textContent = 'Disconnect';
+                 startStreamBtn.disabled = false;
+                 sceneSelect.disabled = false;
+             };
+
+             socket.onmessage = (event) => {
+                 const message = JSON.parse(event.data);
+
+                 switch (message.action) {
+                     case 'welcome':
+                         userId = message.userId;
+                         logMessage(`Connected with user ID: ${userId}`);
+
+                         // Update scene options if server provides them
+                         if (message.scenes && Array.isArray(message.scenes)) {
+                             sceneSelect.innerHTML = '';
+                             message.scenes.forEach(scene => {
+                                 const option = document.createElement('option');
+                                 option.value = scene;
+                                 option.textContent = scene.charAt(0).toUpperCase() + scene.slice(1);
+                                 sceneSelect.appendChild(option);
+                             });
+                         }
+                         break;
+
+                     case 'frame':
+                         // Process incoming frame
+                         processFrame(message);
+                         break;
+
+                     case 'start_stream':
+                         if (message.success) {
+                             isStreaming = true;
+                             startStreamBtn.disabled = true;
+                             stopStreamBtn.disabled = false;
+                             logMessage(`Streaming started: ${message.message}`);
+
+                             // Start FPS counter
+                             startFpsCounter();
+                         } else {
+                             logMessage(`Error starting stream: ${message.error}`);
+                         }
+                         break;
+
+                     case 'stop_stream':
+                         if (message.success) {
+                             isStreaming = false;
+                             startStreamBtn.disabled = false;
+                             stopStreamBtn.disabled = true;
+                             logMessage('Streaming stopped');
+
+                             // Stop FPS counter
+                             stopFpsCounter();
+                         } else {
+                             logMessage(`Error stopping stream: ${message.error}`);
+                         }
+                         break;
+
+                     case 'pong':
+                         // Server responded to ping
+                         break;
+
+                     case 'change_scene':
+                         if (message.success) {
+                             logMessage(`Scene changed to ${message.scene}`);
+                         } else {
+                             logMessage(`Error changing scene: ${message.error}`);
+                         }
+                         break;
+
+                     default:
+                         logMessage(`Received message: ${JSON.stringify(message)}`);
+                 }
+             };
+
+             socket.onclose = (event) => {
+                 logMessage(`WebSocket connection closed (code: ${event.code}, reason: ${event.reason || 'none given'})`);
+                 resetUI();
+             };
+
+             socket.onerror = (error) => {
+                 logMessage(`WebSocket error. This is often caused by CORS issues or the server being inaccessible.`);
+                 console.error('WebSocket error:', error);
+                 resetUI();
+             };
+         }
+
+         // Disconnect from WebSocket server
+         function disconnectWebSocket() {
+             if (socket && socket.readyState === WebSocket.OPEN) {
+                 // Stop streaming if active
+                 if (isStreaming) {
+                     sendStopStream();
+                 }
+
+                 // Close the socket
+                 socket.close();
+                 logMessage('Disconnected from server');
+             }
+         }
+
+         // Start streaming frames
+         function sendStartStream() {
+             if (socket && socket.readyState === WebSocket.OPEN) {
+                 socket.send(JSON.stringify({
+                     action: 'start_stream',
+                     requestId: generateRequestId(),
+                     fps: 16 // Default FPS
+                 }));
+             }
+         }
+
+         // Stop streaming frames
+         function sendStopStream() {
+             if (socket && socket.readyState === WebSocket.OPEN) {
+                 socket.send(JSON.stringify({
+                     action: 'stop_stream',
+                     requestId: generateRequestId()
+                 }));
+             }
+         }
+
+         // Send keyboard input to server
+         function sendKeyboardInput(key, pressed) {
+             if (socket && socket.readyState === WebSocket.OPEN) {
+                 socket.send(JSON.stringify({
+                     action: 'keyboard_input',
+                     requestId: generateRequestId(),
+                     key: key,
+                     pressed: pressed
+                 }));
+             }
+         }
+
+         // Send mouse input to server
+         function sendMouseInput(x, y) {
+             if (socket && socket.readyState === WebSocket.OPEN && isStreaming) {
+                 socket.send(JSON.stringify({
+                     action: 'mouse_input',
+                     requestId: generateRequestId(),
+                     x: x,
+                     y: y
+                 }));
+             }
+         }
+
+         // Change scene
+         function sendChangeScene(scene) {
+             if (socket && socket.readyState === WebSocket.OPEN) {
+                 socket.send(JSON.stringify({
+                     action: 'change_scene',
+                     requestId: generateRequestId(),
+                     scene: scene
+                 }));
+             }
+         }
+
+         // Process incoming frame
+         function processFrame(message) {
+             // Update FPS calculation
+             const now = performance.now();
+             if (lastFrameTime > 0) {
+                 frameCount++;
+             }
+             lastFrameTime = now;
+
+             // Update the canvas with the new frame
+             if (message.frameData) {
+                 gameCanvas.src = `data:image/jpeg;base64,${message.frameData}`;
+             }
+         }
+
+         // Generate a random request ID
+         function generateRequestId() {
+             return Math.random().toString(36).substring(2, 15);
+         }
+
+         // Log message to the connection info panel
+         function logMessage(message) {
+             const logEntry = document.createElement('div');
+             logEntry.className = 'log-entry';
+
+             const timestamp = new Date().toLocaleTimeString();
+             logEntry.textContent = `[${timestamp}] ${message}`;
+
+             connectionLog.appendChild(logEntry);
+             connectionLog.scrollTop = connectionLog.scrollHeight;
+
+             // Limit number of log entries
+             while (connectionLog.children.length > 100) {
+                 connectionLog.removeChild(connectionLog.firstChild);
+             }
+         }
+
+         // Start FPS counter updates
+         function startFpsCounter() {
+             frameCount = 0;
+             lastFrameTime = 0;
+
+             // Update FPS display every second
+             fpsUpdateInterval = setInterval(() => {
+                 fpsCounter.textContent = `FPS: ${frameCount}`;
+                 frameCount = 0;
+             }, 1000);
+         }
+
+         // Stop FPS counter updates
+         function stopFpsCounter() {
+             if (fpsUpdateInterval) {
+                 clearInterval(fpsUpdateInterval);
+                 fpsUpdateInterval = null;
+             }
+             fpsCounter.textContent = 'FPS: 0';
+         }
+
+         // Reset UI to initial state
+         function resetUI() {
+             connectBtn.textContent = 'Connect';
+             startStreamBtn.disabled = true;
+             stopStreamBtn.disabled = true;
+             sceneSelect.disabled = true;
+
+             // Reset key indicators
+             for (const key in keyElements) {
+                 keyElements[key].classList.remove('active');
+             }
+
+             // Stop FPS counter
+             stopFpsCounter();
+
+             // Reset streaming state
+             isStreaming = false;
+         }
+
+         // Event Listeners
+         connectBtn.addEventListener('click', () => {
+             if (socket && socket.readyState === WebSocket.OPEN) {
+                 disconnectWebSocket();
+             } else {
+                 connectWebSocket();
+             }
+         });
+
+         startStreamBtn.addEventListener('click', sendStartStream);
+         stopStreamBtn.addEventListener('click', sendStopStream);
+
+         sceneSelect.addEventListener('change', () => {
+             sendChangeScene(sceneSelect.value);
+         });
+
+         // Keyboard event listeners
+         document.addEventListener('keydown', (event) => {
+             const key = event.key.toLowerCase();
+
+             // Map key to action
+             let action = keyToAction[key];
+             if (!action && key === ' ') {
+                 action = keyToAction[' ']; // Handle spacebar
+             }
+
+             if (action && !keyState[action]) {
+                 keyState[action] = true;
+
+                 // Update visual indicator
+                 const keyElement = keyElements[key] ||
+                     (key === ' ' ? keyElements['space'] : null) ||
+                     (key === 'shift' ? keyElements['shift'] : null);
+
+                 if (keyElement) {
+                     keyElement.classList.add('active');
+                 }
+
+                 // Send to server
+                 sendKeyboardInput(action, true);
+             }
+
+             // Prevent default actions for game controls
+             if (Object.keys(keyToAction).includes(key) || key === ' ') {
+                 event.preventDefault();
+             }
+         });
+
+         document.addEventListener('keyup', (event) => {
+             const key = event.key.toLowerCase();
+
+             // Map key to action
+             let action = keyToAction[key];
+             if (!action && key === ' ') {
+                 action = keyToAction[' ']; // Handle spacebar
+             }
+
+             if (action && keyState[action]) {
+                 keyState[action] = false;
+
+                 // Update visual indicator
+                 const keyElement = keyElements[key] ||
+                     (key === ' ' ? keyElements['space'] : null) ||
+                     (key === 'shift' ? keyElements['shift'] : null);
779
+
780
+ if (keyElement) {
781
+ keyElement.classList.remove('active');
782
+ }
783
+
784
+ // Send to server
785
+ sendKeyboardInput(action, false);
786
+ }
787
+ });
788
+
789
+ // Mouse capture functions
790
+ function requestPointerLock() {
791
+ if (!mouseState.captured && pointerLockSupported) {
792
+ mouseTrackingArea.requestPointerLock = mouseTrackingArea.requestPointerLock ||
793
+ mouseTrackingArea.mozRequestPointerLock ||
794
+ mouseTrackingArea.webkitRequestPointerLock;
795
+ mouseTrackingArea.requestPointerLock();
796
+ logMessage('Mouse captured. Press ESC to release.');
797
+ }
798
+ }
799
+
800
+ function exitPointerLock() {
801
+ if (mouseState.captured) {
802
+ document.exitPointerLock = document.exitPointerLock ||
803
+ document.mozExitPointerLock ||
804
+ document.webkitExitPointerLock;
805
+ document.exitPointerLock();
806
+ logMessage('Mouse released.');
807
+ }
808
+ }
809
+
810
+ // Handle pointer lock change events
811
+ document.addEventListener('pointerlockchange', pointerLockChangeHandler);
812
+ document.addEventListener('mozpointerlockchange', pointerLockChangeHandler);
813
+ document.addEventListener('webkitpointerlockchange', pointerLockChangeHandler);
814
+
815
+ function pointerLockChangeHandler() {
816
+ if (document.pointerLockElement === mouseTrackingArea ||
817
+ document.mozPointerLockElement === mouseTrackingArea ||
818
+ document.webkitPointerLockElement === mouseTrackingArea) {
819
+ // Pointer is locked, enable mouse movement tracking
820
+ mouseState.captured = true;
821
+ document.addEventListener('mousemove', handleMouseMovement);
822
+ } else {
823
+ // Pointer is unlocked, disable mouse movement tracking
824
+ mouseState.captured = false;
825
+ document.removeEventListener('mousemove', handleMouseMovement);
826
+ // Reset mouse state
827
+ mouseState.x = 0;
828
+ mouseState.y = 0;
829
+ mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
830
+ throttledSendMouseInput();
831
+ }
832
+ }
833
+
834
+ // Mouse tracking with pointer lock
835
+ function handleMouseMovement(event) {
836
+ if (mouseState.captured) {
837
+ // Use movement for mouse look when captured
838
+ const sensitivity = 0.005; // Adjust sensitivity
839
+ mouseState.x += event.movementX * sensitivity;
840
+ mouseState.y -= event.movementY * sensitivity; // Invert Y for intuitive camera control
841
+
842
+ // Clamp values
843
+ mouseState.x = Math.max(-1, Math.min(1, mouseState.x));
844
+ mouseState.y = Math.max(-1, Math.min(1, mouseState.y));
845
+
846
+ // Update display
847
+ mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
848
+
849
+ // Send to server (throttled)
850
+ throttledSendMouseInput();
851
+ }
852
+ }
853
+
854
+ // Mouse click to capture
855
+ mouseTrackingArea.addEventListener('click', () => {
856
+ if (!mouseState.captured && isStreaming) {
857
+ requestPointerLock();
858
+ }
859
+ });
860
+
861
+ // Standard mouse tracking for when pointer is not locked
862
+ mouseTrackingArea.addEventListener('mousemove', (event) => {
863
+ if (!mouseState.captured) {
864
+ // Calculate normalized coordinates relative to the center of the tracking area
865
+ const rect = mouseTrackingArea.getBoundingClientRect();
866
+ const centerX = rect.width / 2;
867
+ const centerY = rect.height / 2;
868
+
869
+ // Calculate relative position from center (-1 to 1)
870
+ const relX = (event.clientX - rect.left - centerX) / centerX;
871
+ const relY = (event.clientY - rect.top - centerY) / centerY;
872
+
873
+ // Scale down for smoother movement
874
+ const scaleFactor = 0.05;
875
+ mouseState.x = relX * scaleFactor;
876
+ mouseState.y = -relY * scaleFactor; // Invert Y for intuitive camera control
877
+
878
+ // Update display
879
+ mousePosition.textContent = `Mouse: ${mouseState.x.toFixed(2)}, ${mouseState.y.toFixed(2)}`;
880
+
881
+ // Send to server (throttled)
882
+ throttledSendMouseInput();
883
+ }
884
+ });
885
+
886
+ // Throttle mouse movement to avoid flooding the server
887
+ const throttledSendMouseInput = (() => {
888
+ let lastSentTime = 0;
889
+ const interval = 50; // milliseconds
890
+
891
+ return () => {
892
+ const now = performance.now();
893
+ if (now - lastSentTime >= interval) {
894
+ sendMouseInput(mouseState.x, mouseState.y);
895
+ lastSentTime = now;
896
+ }
897
+ };
898
+ })();
899
+
900
+ // Toggle panel collapse/expand
901
+ function togglePanel(panelId) {
902
+ const panel = document.getElementById(panelId);
903
+ const button = panel.querySelector('.toggle-button');
904
+
905
+ if (panel.classList.contains('collapsed')) {
906
+ // Expand the panel
907
+ panel.classList.remove('collapsed');
908
+ button.textContent = '−'; // Minus sign
909
+ } else {
910
+ // Collapse the panel
911
+ panel.classList.add('collapsed');
912
+ button.textContent = '+'; // Plus sign
913
+ }
914
+ }
915
+
916
+ // Initialize the UI
917
+ resetUI();
918
+
919
+ // Make panel headers clickable
920
+ document.querySelectorAll('.panel-header').forEach(header => {
921
+ header.addEventListener('click', () => {
922
+ const panelId = header.parentElement.id;
923
+ togglePanel(panelId);
924
+ });
925
+ });
926
+ </script>
927
+ </body>
928
+ </html>
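The page above drives the server through a single WebSocket carrying small JSON messages (start_stream, stop_stream, keyboard_input, mouse_input, change_scene, each tagged with a requestId). For reference, here is a minimal Python sketch of the same handshake. It assumes the server listens on ws://localhost:8080/ws (the path registered by the reference api.py below; server.py may differ) and that frames arrive as messages carrying a base64 JPEG in frameData, as consumed by processFrame() above; the "frame" action name itself is an assumption.

# Minimal sketch of the JSON-over-WebSocket protocol exercised by the page above.
# Assumptions: a ws://localhost:8080/ws endpoint, frames sent as {"action": "frame", ...}.
import asyncio
import json
import uuid

import aiohttp  # already listed as a dependency in this repo

async def main():
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect("ws://localhost:8080/ws") as ws:
            # Ask the server to start streaming at 16 FPS, mirroring sendStartStream().
            await ws.send_json({
                "action": "start_stream",
                "requestId": uuid.uuid4().hex,
                "fps": 16,
            })
            async for msg in ws:
                if msg.type != aiohttp.WSMsgType.TEXT:
                    break
                message = json.loads(msg.data)
                if message.get("action") == "frame" and message.get("frameData"):
                    # frameData is a base64-encoded JPEG, as in processFrame() above.
                    print(f"got frame ({len(message['frameData'])} base64 chars)")

asyncio.run(main())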
reference_example/Dockerfile ADDED
@@ -0,0 +1,52 @@
+ FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
+
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ ENV PYTHONUNBUFFERED=1
+
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+     build-essential \
+     python3.11 \
+     python3-pip \
+     python3-dev \
+     git \
+     curl \
+     ffmpeg \
+     libglib2.0-0 \
+     libsm6 \
+     libxrender1 \
+     libxext6 \
+     && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+ # Switch to the "user" user
+ USER user
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ # Set the Python path, unbuffered output, and the data root
+ ENV PYTHONPATH=$HOME/app \
+     PYTHONUNBUFFERED=1 \
+     DATA_ROOT=/tmp/data
+
+ RUN echo "Installing requirements.txt"
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # yeah.. this is manual for now
+ #RUN flutter build web
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ EXPOSE 8080
+
+ ENV PORT 8080
+
+ CMD python3 api.py
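This reference Dockerfile follows the usual Hugging Face Spaces pattern: install system packages as root, then drop to a non-root user with UID 1000 and keep writable state under $HOME and /tmp. A quick sanity check you could run inside the container (illustrative only, not part of the repo):

# Sketch: verify the non-root setup and environment the Dockerfile above prepares.
import os
import pwd

print(pwd.getpwuid(os.getuid()).pw_name)  # expected: 'user' (uid 1000)
print(os.environ.get("HOME"))             # expected: /home/user
print(os.environ.get("DATA_ROOT"))        # expected: /tmp/data (writable scratch dir)
print(os.environ.get("PORT"))             # expected: 8080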
reference_example/api.py ADDED
@@ -0,0 +1,297 @@
+ import asyncio
+ import json
+ import logging
+ import os
+ import pathlib
+ import time
+ import uuid
+ from aiohttp import web, WSMsgType
+ from typing import Dict, Any
+
+ from api_core import VideoGenerationAPI
+ from api_session import SessionManager
+ from api_metrics import MetricsTracker
+ from api_config import *
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Create global session and metrics managers
+ session_manager = SessionManager()
+ metrics_tracker = MetricsTracker()
+
+ # Dictionary to track connected anonymous clients by IP address
+ anon_connections = {}
+ anon_connection_lock = asyncio.Lock()
+
+ async def status_handler(request: web.Request) -> web.Response:
+     """Handler for API status endpoint"""
+     api = session_manager.shared_api
+
+     # Get current busy status of all endpoints
+     endpoint_statuses = []
+     for ep in api.endpoint_manager.endpoints:
+         endpoint_statuses.append({
+             'id': ep.id,
+             'url': ep.url,
+             'busy': ep.busy,
+             'last_used': ep.last_used,
+             'error_count': ep.error_count,
+             'error_until': ep.error_until
+         })
+
+     # Get session statistics
+     session_stats = session_manager.get_session_stats()
+
+     # Get metrics
+     api_metrics = metrics_tracker.get_metrics()
+
+     return web.json_response({
+         'product': PRODUCT_NAME,
+         'version': PRODUCT_VERSION,
+         'maintenance_mode': MAINTENANCE_MODE,
+         'available_endpoints': len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS),
+         'endpoint_status': endpoint_statuses,
+         'active_endpoints': sum(1 for ep in endpoint_statuses if not ep['busy'] and ('error_until' not in ep or ep['error_until'] < time.time())),
+         'active_sessions': session_stats,
+         'metrics': api_metrics
+     })
+
+ async def metrics_handler(request: web.Request) -> web.Response:
+     """Handler for detailed metrics endpoint (protected)"""
+     # Check for API key in header or query param
+     auth_header = request.headers.get('Authorization', '')
+     api_key = None
+
+     if auth_header.startswith('Bearer '):
+         api_key = auth_header[7:]
+     else:
+         api_key = request.query.get('key')
+
+     # Validate API key (using SECRET_TOKEN as the API key)
+     if not api_key or api_key != SECRET_TOKEN:
+         return web.json_response({
+             'error': 'Unauthorized'
+         }, status=401)
+
+     # Get detailed metrics
+     detailed_metrics = metrics_tracker.get_detailed_metrics()
+
+     return web.json_response(detailed_metrics)
+
+ async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
+     # Check if maintenance mode is enabled
+     if MAINTENANCE_MODE:
+         # Return an error response indicating maintenance mode
+         return web.json_response({
+             'error': 'Server is in maintenance mode',
+             'maintenance': True
+         }, status=503)  # 503 Service Unavailable
+
+     ws = web.WebSocketResponse(
+         max_msg_size=1024*1024*20,  # 20MB max message size
+         timeout=30.0  # we want to keep things tight and short
+     )
+
+     await ws.prepare(request)
+
+     # Get the Hugging Face token from query parameters
+     hf_token = request.query.get('hf_token', '')
+
+     # Generate a unique user ID for this connection
+     user_id = str(uuid.uuid4())
+
+     # Validate the token and determine the user role
+     user_role = await session_manager.shared_api.validate_user_token(hf_token)
+     logger.info(f"User {user_id} connected with role: {user_role}")
+
+     # Get client IP address
+     peername = request.transport.get_extra_info('peername')
+     if peername is not None:
+         client_ip = peername[0]
+     else:
+         client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()
+
+     logger.info(f"Client {user_id} connecting from IP: {client_ip} with role: {user_role}")
+
+     # Check for anonymous user connection limits
+     if user_role == 'anon':
+         async with anon_connection_lock:
+             # Track this connection
+             anon_connections[client_ip] = anon_connections.get(client_ip, 0) + 1
+             # Store the IP so we can clean up later
+             ws.client_ip = client_ip
+
+             # Log multiple connections from same IP but don't restrict them
+             if anon_connections[client_ip] > 1:
+                 logger.info(f"Multiple anonymous connections from IP {client_ip}: {anon_connections[client_ip]} connections")
+
+     # Store the user role in the websocket for easy access
+     ws.user_role = user_role
+     ws.user_id = user_id
+
+     # Register with metrics
+     metrics_tracker.register_session(user_id, client_ip)
+
+     # Create a new session for this user
+     user_session = await session_manager.create_session(user_id, user_role, ws)
+
+     try:
+         async for msg in ws:
+             if msg.type == WSMsgType.TEXT:
+                 try:
+                     data = json.loads(msg.data)
+                     action = data.get('action')
+
+                     # Check for rate limiting
+                     request_type = 'other'
+                     if action in ['join_chat', 'leave_chat', 'chat_message']:
+                         request_type = 'chat'
+                     elif action in ['generate_video']:
+                         request_type = 'video'
+                     elif action == 'search':
+                         request_type = 'search'
+                     elif action == 'simulate':
+                         request_type = 'simulation'
+
+                     # Record the request for metrics
+                     await metrics_tracker.record_request(user_id, client_ip, request_type, user_role)
+
+                     # Check rate limits (except for admins)
+                     if user_role != 'admin' and await metrics_tracker.is_rate_limited(user_id, request_type, user_role):
+                         await ws.send_json({
+                             'action': action,
+                             'requestId': data.get('requestId'),
+                             'success': False,
+                             'error': f'Rate limit exceeded for {request_type} requests. Please try again later.'
+                         })
+                         continue
+
+                     # Route requests to appropriate queues
+                     if action in ['join_chat', 'leave_chat', 'chat_message']:
+                         await user_session.chat_queue.put(data)
+                     elif action in ['generate_video']:
+                         await user_session.video_queue.put(data)
+                     elif action == 'search':
+                         await user_session.search_queue.put(data)
+                     elif action == 'simulate':
+                         await user_session.simulation_queue.put(data)
+                     else:
+                         await user_session.process_generic_request(data)
+
+                 except Exception as e:
+                     logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
+                     await ws.send_json({
+                         'action': data.get('action') if 'data' in locals() else 'unknown',
+                         'success': False,
+                         'error': f'Error processing message: {str(e)}'
+                     })
+
+             elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
+                 break
+
+     finally:
+         # Cleanup session
+         await session_manager.delete_session(user_id)
+
+         # Cleanup anonymous connection tracking
+         if getattr(ws, 'user_role', None) == 'anon' and hasattr(ws, 'client_ip'):
+             client_ip = ws.client_ip
+             async with anon_connection_lock:
+                 if client_ip in anon_connections:
+                     anon_connections[client_ip] = max(0, anon_connections[client_ip] - 1)
+                     if anon_connections[client_ip] == 0:
+                         del anon_connections[client_ip]
+                     logger.info(f"Anonymous connection from {client_ip} closed. Remaining: {anon_connections.get(client_ip, 0)}")
+
+         # Unregister from metrics
+         metrics_tracker.unregister_session(user_id, client_ip)
+         logger.info(f"Connection closed for user {user_id}")
+
+     return ws
+
+ async def init_app() -> web.Application:
+     app = web.Application(
+         client_max_size=1024**2*20  # 20MB max size
+     )
+
+     # Add cleanup logic
+     async def cleanup(app):
+         logger.info("Shutting down server, closing all sessions...")
+         await session_manager.close_all_sessions()
+
+     app.on_shutdown.append(cleanup)
+
+     # Add routes
+     app.router.add_get('/ws', websocket_handler)
+     app.router.add_get('/api/status', status_handler)
+     app.router.add_get('/api/metrics', metrics_handler)
+
+     # Set up static file serving
+     # Define the path to the public directory
+     public_path = pathlib.Path(__file__).parent / 'build' / 'web'
+     if not public_path.exists():
+         public_path.mkdir(parents=True, exist_ok=True)
+
+     # Set up static file serving with proper security considerations
+     async def static_file_handler(request):
+         # Get the path from the request (removing leading /)
+         path_parts = request.path.lstrip('/').split('/')
+
+         # Convert to safe path to prevent path traversal attacks
+         safe_path = public_path.joinpath(*path_parts)
+
+         # Make sure the path is within the public directory (prevent directory traversal)
+         try:
+             safe_path = safe_path.resolve()
+             if not str(safe_path).startswith(str(public_path.resolve())):
+                 return web.HTTPForbidden(text="Access denied")
+         except (ValueError, FileNotFoundError):
+             return web.HTTPNotFound()
+
+         # If path is a directory, look for index.html
+         if safe_path.is_dir():
+             safe_path = safe_path / 'index.html'
+
+         # Check if the file exists
+         if not safe_path.exists() or not safe_path.is_file():
+             # If not found, serve index.html (for SPA routing)
+             safe_path = public_path / 'index.html'
+             if not safe_path.exists():
+                 return web.HTTPNotFound()
+
+         # Determine content type based on file extension
+         content_type = 'text/plain'
+         ext = safe_path.suffix.lower()
+         if ext == '.html':
+             content_type = 'text/html'
+         elif ext == '.js':
+             content_type = 'application/javascript'
+         elif ext == '.css':
+             content_type = 'text/css'
+         elif ext in ('.jpg', '.jpeg'):
+             content_type = 'image/jpeg'
+         elif ext == '.png':
+             content_type = 'image/png'
+         elif ext == '.gif':
+             content_type = 'image/gif'
+         elif ext == '.svg':
+             content_type = 'image/svg+xml'
+         elif ext == '.json':
+             content_type = 'application/json'
+
+         # Return the file with appropriate headers
+         return web.FileResponse(safe_path, headers={'Content-Type': content_type})
+
+     # Add catch-all route for static files (lower priority than API routes)
+     app.router.add_get('/{path:.*}', static_file_handler)
+
+     return app
+
+ if __name__ == '__main__':
+     app = asyncio.run(init_app())
+     web.run_app(app, host='0.0.0.0', port=8080)
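The two plain HTTP routes registered above are handy for smoke tests. A short sketch, assuming the server runs on localhost:8080 and SECRET_TOKEN is exported in the environment (per metrics_handler above, the token can go in an Authorization: Bearer header or a ?key= query parameter):

# Sketch: probe /api/status and the token-protected /api/metrics defined in api.py above.
import asyncio
import os

import aiohttp

async def main():
    base = "http://localhost:8080"
    token = os.environ.get("SECRET_TOKEN", "")
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base}/api/status") as resp:
            status = await resp.json()
            print("version:", status["version"], "- active endpoints:", status["active_endpoints"])
        headers = {"Authorization": f"Bearer {token}"}
        async with session.get(f"{base}/api/metrics", headers=headers) as resp:
            print("metrics:", resp.status)  # 401 when the token does not match SECRET_TOKEN

asyncio.run(main())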
reference_example/api_config.py ADDED
@@ -0,0 +1,184 @@
+ import os
+
+ PRODUCT_NAME = os.environ.get('PRODUCT_NAME', 'TikSlop')
+ PRODUCT_VERSION = "2.0.0"
+
+ # you should use Mistral 7B Instruct for a good balance between performance and accuracy
+ TEXT_MODEL = os.environ.get('HF_TEXT_MODEL', '')
+
+ # Environment variable to control maintenance mode
+ MAINTENANCE_MODE = os.environ.get('MAINTENANCE_MODE', 'false').lower() in ('true', 'yes', '1', 't')
+
+ # Environment variable to control how many nodes to use
+ MAX_NODES = int(os.environ.get('MAX_NODES', '8'))
+
+ ADMIN_ACCOUNTS = [
+     "jbilcke-hf"
+ ]
+
+ RAW_VIDEO_ROUND_ROBIN_ENDPOINT_URLS = [
+     os.environ.get('VIDEO_ROUND_ROBIN_SERVER_1', ''),
+     os.environ.get('VIDEO_ROUND_ROBIN_SERVER_2', ''),
+     os.environ.get('VIDEO_ROUND_ROBIN_SERVER_3', ''),
+     os.environ.get('VIDEO_ROUND_ROBIN_SERVER_4', ''),
+     os.environ.get('VIDEO_ROUND_ROBIN_SERVER_5', ''),
+     os.environ.get('VIDEO_ROUND_ROBIN_SERVER_6', ''),
+     os.environ.get('VIDEO_ROUND_ROBIN_SERVER_7', ''),
+     os.environ.get('VIDEO_ROUND_ROBIN_SERVER_8', ''),
+ ]
+
+ # Filter out empty strings from the endpoint list
+ filtered_urls = [url for url in RAW_VIDEO_ROUND_ROBIN_ENDPOINT_URLS if url]
+
+ # Limit the number of URLs based on the MAX_NODES environment variable
+ VIDEO_ROUND_ROBIN_ENDPOINT_URLS = filtered_urls[:MAX_NODES]
+
+ HF_TOKEN = os.environ.get('HF_TOKEN')
+
+ # use the same secret token as you used to secure your BASE_SPACE_NAME spaces
+ SECRET_TOKEN = os.environ.get('SECRET_TOKEN')
+
+ # alternative words we could use: "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"
+ NEGATIVE_PROMPT = "low quality, worst quality, deformed, distorted, disfigured, blurry, text, watermark"
+
+ POSITIVE_PROMPT_SUFFIX = "high quality, cinematic, 4K, intricate details"
+
+ GUIDANCE_SCALE = 1.0
+
+ THUMBNAIL_FRAMES = 65
+
+ # anonymous users are people browsing TikSlop without being logged in;
+ # this category suffers from regular abuse so we need to enforce strict limitations
+ CONFIG_FOR_ANONYMOUS_USERS = {
+
+     # anons can only watch 2 minutes per video
+     "max_rendering_time_per_client_per_video_in_sec": 2 * 60,
+
+     "min_num_inference_steps": 2,
+     "default_num_inference_steps": 4,
+     "max_num_inference_steps": 4,
+
+     "min_num_frames": 9,  # 8 + 1
+     "default_max_num_frames": 65,  # 8*8 + 1
+     "max_num_frames": 65,  # 8*8 + 1
+
+     "min_clip_duration_seconds": 1,
+     "default_clip_duration_seconds": 2,
+     "max_clip_duration_seconds": 2,
+
+     "min_clip_playback_speed": 0.7,
+     "default_clip_playback_speed": 0.7,
+     "max_clip_playback_speed": 0.7,
+
+     "min_clip_framerate": 8,
+     "default_clip_framerate": 16,
+     "max_clip_framerate": 16,
+
+     "min_clip_width": 544,
+     "default_clip_width": 640,
+     "max_clip_width": 640,
+
+     "min_clip_height": 320,
+     "default_clip_height": 352,
+     "max_clip_height": 352,
+ }
+
+ # Hugging Face users enjoy a more normal and calibrated experience
+ CONFIG_FOR_STANDARD_HF_USERS = {
+     "max_rendering_time_per_client_per_video_in_sec": 15 * 60,
+
+     "min_num_inference_steps": 2,
+     "default_num_inference_steps": 4,
+     "max_num_inference_steps": 4,
+
+     "min_num_frames": 9,  # 8 + 1
+     "default_num_frames": 81,  # 8*10 + 1
+     "max_num_frames": 81,
+
+     "min_clip_duration_seconds": 1,
+     "default_clip_duration_seconds": 3,
+     "max_clip_duration_seconds": 3,
+
+     "min_clip_playback_speed": 0.7,
+     "default_clip_playback_speed": 0.7,
+     "max_clip_playback_speed": 0.7,
+
+     "min_clip_framerate": 8,
+     "default_clip_framerate": 25,
+     "max_clip_framerate": 25,
+
+     "min_clip_width": 544,
+     "default_clip_width": 1152,  # 928, # 1216, # 768, # 640,
+     "max_clip_width": 1152,  # 928, # 1216, # 768, # 640,
+
+     "min_clip_height": 320,
+     "default_clip_height": 640,  # 512, # 448, # 416,
+     "max_clip_height": 640,  # 512, # 448, # 416,
+ }
+
+ # Hugging Face users with a Pro subscription enjoy an improved experience
+ CONFIG_FOR_PRO_HF_USERS = {
+     "max_rendering_time_per_client_per_video_in_sec": 20 * 60,
+
+     "min_num_inference_steps": 2,
+     "default_num_inference_steps": 4,
+     "max_num_inference_steps": 4,
+
+     "min_num_frames": 9,  # 8 + 1
+     "default_num_frames": 81,  # 8*10 + 1
+     "max_num_frames": 81,
+
+     "min_clip_duration_seconds": 1,
+     "default_clip_duration_seconds": 3,
+     "max_clip_duration_seconds": 3,
+
+     "min_clip_playback_speed": 0.7,
+     "default_clip_playback_speed": 0.7,
+     "max_clip_playback_speed": 0.7,
+
+     "min_clip_framerate": 8,
+     "default_clip_framerate": 25,
+     "max_clip_framerate": 25,
+
+     "min_clip_width": 544,
+     "default_clip_width": 1152,  # 928, # 1216, # 768, # 640,
+     "max_clip_width": 1152,  # 928, # 1216, # 768, # 640,
+
+     "min_clip_height": 320,
+     "default_clip_height": 640,  # 512, # 448, # 416,
+     "max_clip_height": 640,  # 512, # 448, # 416,
+ }
+
+ CONFIG_FOR_ADMIN_HF_USERS = {
+     "max_rendering_time_per_client_per_video_in_sec": 60 * 60,
+
+     "min_num_inference_steps": 2,
+     "default_num_inference_steps": 4,
+     "max_num_inference_steps": 4,
+
+     "min_num_frames": 9,  # 8 + 1
+     "default_num_frames": 81,  # (8 * 10) + 1
+     "max_num_frames": 129,  # (8 * 16) + 1
+
+     "min_clip_duration_seconds": 1,
+     "default_clip_duration_seconds": 2,
+     "max_clip_duration_seconds": 4,
+
+     "min_clip_playback_speed": 0.7,
+     "default_clip_playback_speed": 0.7,
+     "max_clip_playback_speed": 1.0,
+
+     "min_clip_framerate": 8,
+     "default_clip_framerate": 30,
+     "max_clip_framerate": 60,
+
+     "min_clip_width": 544,
+     "default_clip_width": 1152,  # 928, # 1216, # 768, # 640,
+     "max_clip_width": 1152,  # 928, # 1216, # 768, # 640,
+
+     "min_clip_height": 320,
+     "default_clip_height": 640,  # 512, # 448, # 416,
+     "max_clip_height": 640,  # 512, # 448, # 416,
+ }
+
+ # For now, admin users share the Pro configuration (this rebinds the name defined just above)
+ CONFIG_FOR_ADMIN_HF_USERS = CONFIG_FOR_PRO_HF_USERS
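Each role config above is a flat dict of min_/default_/max_ triplets; get_config_value() in api_core.py below selects the dict for the user's role and clamps any client-supplied option into the [min, max] window. For example, with CONFIG_FOR_ANONYMOUS_USERS (min_clip_framerate=8, max_clip_framerate=16, default_clip_framerate=16), the behaviour is as sketched here with illustrative values (note, in passing, that the anonymous config names its frame cap default_max_num_frames while the other roles use default_num_frames):

# Sketch: how the anonymous-role limits above are applied by get_config_value()
# from api_core.py (shown later in this commit).
api = VideoGenerationAPI()

print(api.get_config_value('anon', 'clip_framerate', {'clip_framerate': 60}))  # 16 (clamped to max)
print(api.get_config_value('anon', 'clip_framerate', {'clip_framerate': 12}))  # 12 (within bounds)
print(api.get_config_value('anon', 'clip_framerate'))                          # 16 (role default)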
reference_example/api_core.py ADDED
@@ -0,0 +1,1068 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import io
4
+ import re
5
+ import base64
6
+ import uuid
7
+ from typing import Dict, Any, Optional, List, Literal
8
+ from dataclasses import dataclass
9
+ from asyncio import Lock, Queue
10
+ import asyncio
11
+ import time
12
+ import datetime
13
+ from contextlib import asynccontextmanager
14
+ from collections import defaultdict
15
+ from aiohttp import web, ClientSession
16
+ from huggingface_hub import InferenceClient, HfApi
17
+ from gradio_client import Client
18
+ import random
19
+ import yaml
20
+ import json
21
+
22
+ from api_config import *
23
+
24
+ # User role type
25
+ UserRole = Literal['anon', 'normal', 'pro', 'admin']
26
+
27
+ # Configure logging
28
+ logging.basicConfig(
29
+ level=logging.INFO,
30
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
31
+ )
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def generate_seed():
36
+ """Generate a random positive 32-bit integer seed."""
37
+ return random.randint(0, 2**32 - 1)
38
+
39
+ def sanitize_yaml_response(response_text: str) -> str:
40
+ """
41
+ Sanitize and format AI response into valid YAML.
42
+ Returns properly formatted YAML string.
43
+ """
44
+
45
+ response_text = response_text.split("```")[0]
46
+
47
+ # Remove any markdown code block indicators and YAML document markers
48
+ clean_text = re.sub(r'```yaml|```|---|\.\.\.$', '', response_text.strip())
49
+
50
+ # Split into lines and process each line
51
+ lines = clean_text.split('\n')
52
+ sanitized_lines = []
53
+ current_field = None
54
+
55
+ for line in lines:
56
+ stripped = line.strip()
57
+ if not stripped:
58
+ continue
59
+
60
+ # Handle field starts
61
+ if stripped.startswith('title:') or stripped.startswith('description:'):
62
+ # Ensure proper YAML format with space after colon and proper quoting
63
+ field_name = stripped.split(':', 1)[0]
64
+ field_value = stripped.split(':', 1)[1].strip().strip('"\'')
65
+
66
+ # Quote the value if it contains special characters
67
+ if any(c in field_value for c in ':[]{},&*#?|-<>=!%@`'):
68
+ field_value = f'"{field_value}"'
69
+
70
+ sanitized_lines.append(f"{field_name}: {field_value}")
71
+ current_field = field_name
72
+
73
+ elif stripped.startswith('tags:'):
74
+ sanitized_lines.append('tags:')
75
+ current_field = 'tags'
76
+
77
+ elif stripped.startswith('-') and current_field == 'tags':
78
+ # Process tag values
79
+ tag = stripped[1:].strip().strip('"\'')
80
+ if tag:
81
+ # Clean and format tag
82
+ tag = re.sub(r'[^\x00-\x7F]+', '', tag) # Remove non-ASCII
83
+ tag = re.sub(r'[^a-zA-Z0-9\s-]', '', tag) # Keep only alphanumeric and hyphen
84
+ tag = tag.strip().lower().replace(' ', '-')
85
+ if tag:
86
+ sanitized_lines.append(f" - {tag}")
87
+
88
+ elif current_field in ['title', 'description']:
89
+ # Handle multi-line title/description continuation
90
+ value = stripped.strip('"\'')
91
+ if value:
92
+ # Append to previous line
93
+ prev = sanitized_lines[-1]
94
+ sanitized_lines[-1] = f"{prev} {value}"
95
+
96
+ # Ensure the YAML has all required fields
97
+ required_fields = {'title', 'description', 'tags'}
98
+ found_fields = {line.split(':')[0].strip() for line in sanitized_lines if ':' in line}
99
+
100
+ for field in required_fields - found_fields:
101
+ if field == 'tags':
102
+ sanitized_lines.extend(['tags:', ' - default'])
103
+ else:
104
+ sanitized_lines.append(f'{field}: "No {field} provided"')
105
+
106
+ return '\n'.join(sanitized_lines)
107
+
108
+ @dataclass
109
+ class Endpoint:
110
+ id: int
111
+ url: str
112
+ busy: bool = False
113
+ last_used: float = 0
114
+ error_count: int = 0
115
+ error_until: float = 0 # Timestamp until which this endpoint is considered in error state
116
+
117
+ class EndpointManager:
118
+ def __init__(self):
119
+ self.endpoints: List[Endpoint] = []
120
+ self.lock = Lock()
121
+ self.initialize_endpoints()
122
+ self.last_used_index = -1 # Track the last used endpoint for round-robin
123
+
124
+ def initialize_endpoints(self):
125
+ """Initialize the list of endpoints"""
126
+ for i, url in enumerate(VIDEO_ROUND_ROBIN_ENDPOINT_URLS):
127
+ endpoint = Endpoint(id=i + 1, url=url)
128
+ self.endpoints.append(endpoint)
129
+
130
+ def _get_next_free_endpoint(self):
131
+ """Get the next available non-busy endpoint, or oldest endpoint if all are busy"""
132
+ current_time = time.time()
133
+
134
+ # First priority: Get any non-busy and non-error endpoint
135
+ free_endpoints = [
136
+ ep for ep in self.endpoints
137
+ if not ep.busy and current_time > ep.error_until
138
+ ]
139
+
140
+ if free_endpoints:
141
+ # Return the least recently used free endpoint
142
+ return min(free_endpoints, key=lambda ep: ep.last_used)
143
+
144
+ # Second priority: If all busy/error, use round-robin but skip error endpoints
145
+ tried_count = 0
146
+ next_index = self.last_used_index
147
+
148
+ while tried_count < len(self.endpoints):
149
+ next_index = (next_index + 1) % len(self.endpoints)
150
+ tried_count += 1
151
+
152
+ # If endpoint is not in error state, use it
153
+ if current_time > self.endpoints[next_index].error_until:
154
+ self.last_used_index = next_index
155
+ return self.endpoints[next_index]
156
+
157
+ # If all endpoints are in error state, use the one with earliest error expiry
158
+ self.last_used_index = next_index
159
+ return min(self.endpoints, key=lambda ep: ep.error_until)
160
+
161
+ @asynccontextmanager
162
+ async def get_endpoint(self, max_wait_time: int = 10):
163
+ """Get the next available endpoint using a context manager"""
164
+ start_time = time.time()
165
+ endpoint = None
166
+
167
+ try:
168
+ while True:
169
+ if time.time() - start_time > max_wait_time:
170
+ raise TimeoutError(f"Could not acquire an endpoint within {max_wait_time} seconds")
171
+
172
+ async with self.lock:
173
+ # Get the next available endpoint using our selection strategy
174
+ endpoint = self._get_next_free_endpoint()
175
+
176
+ # Mark it as busy
177
+ endpoint.busy = True
178
+ endpoint.last_used = time.time()
179
+ #logger.info(f"Using endpoint {endpoint.id} (busy: {endpoint.busy}, last used: {endpoint.last_used})")
180
+ break
181
+
182
+ yield endpoint
183
+
184
+ finally:
185
+ if endpoint:
186
+ async with self.lock:
187
+ endpoint.busy = False
188
+ endpoint.last_used = time.time()
189
+ # We don't need to put back into queue - our strategy now picks directly from the list
190
+
191
+ class ChatRoom:
192
+ def __init__(self):
193
+ self.messages = []
194
+ self.connected_clients = set()
195
+ self.max_history = 100
196
+
197
+ def add_message(self, message):
198
+ self.messages.append(message)
199
+ if len(self.messages) > self.max_history:
200
+ self.messages.pop(0)
201
+
202
+ def get_recent_messages(self, limit=50):
203
+ return self.messages[-limit:]
204
+
205
+ class VideoGenerationAPI:
206
+ def __init__(self):
207
+ self.inference_client = InferenceClient(token=HF_TOKEN)
208
+ self.hf_api = HfApi(token=HF_TOKEN)
209
+ self.endpoint_manager = EndpointManager()
210
+ self.active_requests: Dict[str, asyncio.Future] = {}
211
+ self.chat_rooms = defaultdict(ChatRoom)
212
+ self.video_events: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
213
+ self.event_history_limit = 50
214
+ # Cache for user roles to avoid repeated API calls
215
+ self.user_role_cache: Dict[str, Dict[str, Any]] = {}
216
+ # Cache expiration time (10 minutes)
217
+ self.cache_expiration = 600
218
+
219
+
220
+ def _add_event(self, video_id: str, event: Dict[str, Any]):
221
+ """Add an event to the video's history and maintain the size limit"""
222
+ events = self.video_events[video_id]
223
+ events.append(event)
224
+ if len(events) > self.event_history_limit:
225
+ events.pop(0)
226
+
227
+ async def validate_user_token(self, token: str) -> UserRole:
228
+ """
229
+ Validates a Hugging Face token and determines the user's role.
230
+
231
+ Returns one of:
232
+ - 'anon': Anonymous user (no token or invalid token)
233
+ - 'normal': Standard Hugging Face user
234
+ - 'pro': Hugging Face Pro user
235
+ - 'admin': Admin user (username in ADMIN_ACCOUNTS)
236
+ """
237
+ # If no token is provided, the user is anonymous
238
+ if not token:
239
+ return 'anon'
240
+
241
+ # Check if we have a cached result for this token
242
+ current_time = time.time()
243
+ if token in self.user_role_cache:
244
+ cached_data = self.user_role_cache[token]
245
+ # If the cache is still valid
246
+ if current_time - cached_data['timestamp'] < self.cache_expiration:
247
+ logger.info(f"Using cached user role: {cached_data['role']}")
248
+ return cached_data['role']
249
+
250
+ # No valid cache, need to check the token with the HF API
251
+ try:
252
+ # Use HF API to validate the token and get user info
253
+ logger.info("Validating Hugging Face token...")
254
+
255
+ # Run in executor to avoid blocking the event loop
256
+ user_info = await asyncio.get_event_loop().run_in_executor(
257
+ None,
258
+ lambda: self.hf_api.whoami(token=token)
259
+ )
260
+
261
+ # Handle both object and dict response formats from whoami
262
+ username = user_info.get('name') if isinstance(user_info, dict) else getattr(user_info, 'name', None)
263
+ is_pro = user_info.get('is_pro') if isinstance(user_info, dict) else getattr(user_info, 'is_pro', False)
264
+
265
+ if not username:
266
+ logger.error(f"Could not determine username from user_info: {user_info}")
267
+ return 'anon'
268
+
269
+ logger.info(f"Token valid for user: {username}")
270
+
271
+ # Determine the user role based on the information
272
+ user_role: UserRole
273
+
274
+ # Check if the user is an admin
275
+ if username in ADMIN_ACCOUNTS:
276
+ user_role = 'admin'
277
+ # Check if the user has a pro account
278
+ elif is_pro:
279
+ user_role = 'pro'
280
+ else:
281
+ user_role = 'normal'
282
+
283
+ # Cache the result
284
+ self.user_role_cache[token] = {
285
+ 'role': user_role,
286
+ 'timestamp': current_time,
287
+ 'username': username
288
+ }
289
+
290
+ return user_role
291
+
292
+ except Exception as e:
293
+ logger.error(f"Failed to validate Hugging Face token: {str(e)}")
294
+ # If validation fails, the user is treated as anonymous
295
+ return 'anon'
296
+
297
+ async def download_video(self, url: str) -> bytes:
298
+ """Download video file from URL and return bytes"""
299
+ async with ClientSession() as session:
300
+ async with session.get(url) as response:
301
+ if response.status != 200:
302
+ raise Exception(f"Failed to download video: HTTP {response.status}")
303
+ return await response.read()
304
+
305
+ async def search_video(self, query: str, attempt_count: int = 0) -> Optional[dict]:
306
+ """Generate a single search result using HF text generation"""
307
+ # Maximum number of attempts to generate a description without placeholder tags
308
+ max_attempts = 2
309
+ current_attempt = attempt_count
310
+ # Use a random temperature between 0.68 and 0.72 to generate more diverse results
311
+ # and prevent duplicate results from successive calls with the same prompt
312
+ temperature = random.uniform(0.68, 0.72)
313
+
314
+ while current_attempt <= max_attempts:
315
+ prompt = f"""# Instruction
316
+ Your response MUST be a YAML object containing a title and description, consistent with what we can find on a video sharing platform.
317
+ Format your YAML response with only those fields: "title" (a short string) and "description" (string caption of the scene). Do not add any other field.
318
+ In the description field, describe in a very synthetic way the visuals of the first shot (first scene), eg "<STYLE>, medium close-up shot, high angle view. In the foreground a <OPTIONAL AGE> <OPTIONAL GENDER> <CHARACTERS> <ACTIONS>. In the background <DESCRIBE LOCATION, BACKGROUND CHARACTERS, OBJECTS ETC>. The scene is lit by <LIGHTING> <WEATHER>". This is just an example! you MUST replace the <TAGS>!!.
319
+ Don't forget to replace <STYLE> etc, by the actual fields!!
320
+ For the style, be creative, for instance you can use anything like a "documentary footage", "japanese animation", "movie scene", "tv series", "tv show", "security footage" etc.
321
+ If the user ask for something specific eg "movie screencap", "movie scene", "documentary footage" "animation" as a style etc.
322
+ Keep it minimalist but still descriptive, don't use bullets points, use simple words, go to the essential to describe style (cinematic, documentary footage, 3D rendering..), camera modes and angles, characters, age, gender, action, location, lighting, country, costume, time, weather, textures, color palette.. etc). Write about 80 words, and use between 2 and 3 sentences.
323
+ The most import part is to describe the actions and movements in the scene, so don't forget that!
324
+ Don't describe sound, so ever say things like "atmospheric music playing in the background".
325
+ Instead describe the visual elements we can see in the background, be precise, (if there are anything, cars, objects, people, bricks, birds, clouds, trees, leaves or grass then say it so etc).
326
+ Make the result unique and different from previous search results. ONLY RETURN YAML AND WITH ENGLISH CONTENT, NOT CHINESE - DO NOT ADD ANY OTHER COMMENT!
327
+
328
+ # Context
329
+ This is attempt {current_attempt}.
330
+
331
+ # Input
332
+ Describe the first scene/shot for: "{query}".
333
+
334
+ # Output
335
+
336
+ ```yaml
337
+ title: \""""
338
+
339
+ try:
340
+ response = await asyncio.get_event_loop().run_in_executor(
341
+ None,
342
+ lambda: self.inference_client.text_generation(
343
+ prompt,
344
+ model=TEXT_MODEL,
345
+ max_new_tokens=200,
346
+ temperature=temperature
347
+ )
348
+ )
349
+
350
+ response_text = re.sub(r'^\s*\.\s*\n', '', f"title: \"{response.strip()}")
351
+ sanitized_yaml = sanitize_yaml_response(response_text)
352
+
353
+ try:
354
+ result = yaml.safe_load(sanitized_yaml)
355
+ except yaml.YAMLError as e:
356
+ logger.error(f"YAML parsing failed: {str(e)}")
357
+ result = None
358
+
359
+ if not result or not isinstance(result, dict):
360
+ logger.error(f"Invalid result format: {result}")
361
+ current_attempt += 1
362
+ temperature = random.uniform(0.68, 0.72) # Try with different random temperature on next attempt
363
+ continue
364
+
365
+ # Extract fields with defaults
366
+ title = str(result.get('title', '')).strip() or 'Untitled Video'
367
+ description = str(result.get('description', '')).strip() or 'No description available'
368
+
369
+ # Check if the description still contains placeholder tags like <LOCATION>, <GENDER>, etc.
370
+ if re.search(r'<[A-Z_]+>', description):
371
+ #logger.warning(f"Description still contains placeholder tags: {description}")
372
+ if current_attempt < max_attempts:
373
+ # Try again with a different random temperature
374
+ current_attempt += 1
375
+ temperature = random.uniform(0.68, 0.72)
376
+ continue
377
+ else:
378
+ # If we've reached max attempts, use the title as description
379
+ description = title
380
+
381
+ # Return valid result with all required fields
382
+ return {
383
+ 'id': str(uuid.uuid4()),
384
+ 'title': title,
385
+ 'description': description,
386
+ 'thumbnailUrl': '',
387
+ 'videoUrl': '',
388
+
389
+ # not really used yet, maybe one day if we pre-generate or store content
390
+ 'isLatent': True,
391
+
392
+ 'useFixedSeed': "webcam" in description.lower(),
393
+
394
+ 'seed': generate_seed(),
395
+ 'views': 0,
396
+ 'tags': []
397
+ }
398
+
399
+ except Exception as e:
400
+ logger.error(f"Search video generation failed: {str(e)}")
401
+ current_attempt += 1
402
+ temperature = random.uniform(0.68, 0.72) # Try with different random temperature on next attempt
403
+
404
+ # If all attempts failed, return a simple result with title only
405
+ return {
406
+ 'id': str(uuid.uuid4()),
407
+ 'title': f"Video about {query}",
408
+ 'description': f"Video about {query}",
409
+ 'thumbnailUrl': '',
410
+ 'videoUrl': '',
411
+ 'isLatent': True,
412
+ 'useFixedSeed': "query" in description.lower(),
413
+ 'seed': generate_seed(),
414
+ 'views': 0,
415
+ 'tags': []
416
+ }
417
+
418
+ # The generate_thumbnail function has been removed because we now use
419
+ # generate_video_thumbnail for all thumbnails, which generates a video clip
420
+ # instead of a static image
421
+
422
+ async def generate_caption(self, title: str, description: str) -> str:
423
+ """Generate detailed caption using HF text generation"""
424
+ try:
425
+ prompt = f"""Generate a detailed story for a video named: "{title}"
426
+ Visual description of the video: {description}.
427
+ Instructions: Write the story summary, including the plot, action, what should happen.
428
+ Make it around 200-300 words long.
429
+ A video can be anything from a tutorial, webcam, trailer, movie, live stream etc."""
430
+
431
+ response = await asyncio.get_event_loop().run_in_executor(
432
+ None,
433
+ lambda: self.inference_client.text_generation(
434
+ prompt,
435
+ model=TEXT_MODEL,
436
+ max_new_tokens=180,
437
+ temperature=0.7
438
+ )
439
+ )
440
+
441
+ if "Caption: " in response:
442
+ response = response.replace("Caption: ", "")
443
+
444
+ chunks = f" {response} ".split(". ")
445
+ if len(chunks) > 1:
446
+ text = ". ".join(chunks[:-1])
447
+ else:
448
+ text = response
449
+
450
+ return text.strip()
451
+ except Exception as e:
452
+ logger.error(f"Error generating caption: {str(e)}")
453
+ return ""
454
+
455
+ async def simulate(self, original_title: str, original_description: str,
456
+ current_description: str, condensed_history: str,
457
+ evolution_count: int = 0, chat_messages: str = '') -> dict:
458
+ """
459
+ Simulate a video by evolving its description to create a dynamic narrative.
460
+
461
+ Args:
462
+ original_title: The original video title
463
+ original_description: The original video description
464
+ current_description: The current description (last evolved or original if first evolution)
465
+ condensed_history: A condensed summary of previous scene developments
466
+ evolution_count: How many times the simulation has already evolved
467
+ chat_messages: Chat messages from users to incorporate into the simulation
468
+
469
+ Returns:
470
+ A dictionary containing the evolved description and updated condensed history
471
+ """
472
+ try:
473
+ # Determine if this is the first simulation
474
+ is_first_simulation = evolution_count == 0 or not condensed_history
475
+
476
+ logger.info(f"simulate(): is_first_simulation={is_first_simulation}")
477
+
478
+ # Create an appropriate prompt based on whether this is the first simulation
479
+ chat_section = ""
480
+ if chat_messages:
481
+ chat_section = f"""
482
+ People are watching this content right now and have shared their thoughts. Like a game master, please take their feedback as input to adjust the story and/or the scene. Here are their messages:
483
+
484
+ {chat_messages}
485
+ """
486
+
487
+ if is_first_simulation:
488
+ prompt = f"""You are tasked with evolving the narrative for a video titled: "{original_title}"
489
+
490
+ Original description:
491
+ {original_description}
492
+ {chat_section}
493
+
494
+ Instructions:
495
+ 1. Imagine the next logical scene or development that would follow this description.
496
+ 2. Create a compelling new description (200-300 words) that builds on the original but introduces new elements, developments, or perspectives.
497
+ 3. Maintain the original style, tone, and setting.
498
+ 4. If viewers have shared messages, consider their input and incorporate relevant suggestions or reactions into your narrative evolution.
499
+ 5. Also create a brief "scene history" (50-75 words) that summarizes what has happened so far.
500
+
501
+ Return your response in this format:
502
+ EVOLVED_DESCRIPTION: [your new evolved description here]
503
+ CONDENSED_HISTORY: [your scene history summary]"""
504
+ else:
505
+ prompt = f"""You are tasked with continuing to evolve the narrative for a video titled: "{original_title}"
506
+
507
+ Original description:
508
+ {original_description}
509
+
510
+ Condensed history of scenes so far:
511
+ {condensed_history}
512
+
513
+ Current description (most recent scene):
514
+ {current_description}
515
+ {chat_section}
516
+
517
+ Instructions:
518
+ 1. Imagine the next logical scene or development that would follow the current description.
519
+ 2. Create a compelling new description (200-300 words) that builds on the narrative but introduces new elements, developments, or perspectives.
520
+ 3. Maintain consistency with the previous scenes while advancing the story.
521
+ 4. If viewers have shared messages, consider their input and incorporate relevant suggestions or reactions into your narrative evolution.
522
+ 5. Also update the condensed history (50-75 words) to include this new development.
523
+
524
+ Return your response in this format:
525
+ EVOLVED_DESCRIPTION: [your new evolved description here]
526
+ CONDENSED_HISTORY: [your updated scene history summary]"""
527
+
528
+ # Generate the evolved description
529
+ response = await asyncio.get_event_loop().run_in_executor(
530
+ None,
531
+ lambda: self.inference_client.text_generation(
532
+ prompt,
533
+ model=TEXT_MODEL,
534
+ max_new_tokens=200,
535
+ temperature=0.7
536
+ )
537
+ )
538
+
539
+ # Extract the evolved description and condensed history from the response
540
+ evolved_description = ""
541
+ new_condensed_history = ""
542
+
543
+ # Parse the response
544
+ if "EVOLVED_DESCRIPTION:" in response and "CONDENSED_HISTORY:" in response:
545
+ parts = response.split("CONDENSED_HISTORY:")
546
+ if len(parts) >= 2:
547
+ desc_part = parts[0].strip()
548
+ if "EVOLVED_DESCRIPTION:" in desc_part:
549
+ evolved_description = desc_part.split("EVOLVED_DESCRIPTION:", 1)[1].strip()
550
+ new_condensed_history = parts[1].strip()
551
+
552
+ # If parsing failed, use some fallbacks
553
+ if not evolved_description:
554
+ evolved_description = current_description
555
+ logger.warning(f"Failed to parse evolved description, using current description as fallback")
556
+
557
+ if not new_condensed_history and condensed_history:
558
+ new_condensed_history = condensed_history
559
+ logger.warning(f"Failed to parse condensed history, using current history as fallback")
560
+ elif not new_condensed_history:
561
+ new_condensed_history = f"The video begins with {original_title}: {original_description[:100]}..."
562
+
563
+ return {
564
+ "evolved_description": evolved_description,
565
+ "condensed_history": new_condensed_history
566
+ }
567
+
568
+ except Exception as e:
569
+ logger.error(f"Error simulating video: {str(e)}")
570
+ return {
571
+ "evolved_description": current_description,
572
+ "condensed_history": condensed_history or f"The video shows {original_title}."
573
+ }
574
+
575
+
576
+ def get_config_value(self, role: UserRole, field: str, options: dict = None) -> Any:
577
+ """
578
+ Get the appropriate config value for a user role.
579
+
580
+ Args:
581
+ role: The user role ('anon', 'normal', 'pro', 'admin')
582
+ field: The config field name to retrieve
583
+ options: Optional user-provided options that may override defaults
584
+
585
+ Returns:
586
+ The config value appropriate for the user's role with respect to
587
+ min/max boundaries and user overrides.
588
+ """
589
+ # Select the appropriate config based on user role
590
+ if role == 'admin':
591
+ config = CONFIG_FOR_ADMIN_HF_USERS
592
+ elif role == 'pro':
593
+ config = CONFIG_FOR_PRO_HF_USERS
594
+ elif role == 'normal':
595
+ config = CONFIG_FOR_STANDARD_HF_USERS
596
+ else: # Anonymous users
597
+ config = CONFIG_FOR_ANONYMOUS_USERS
598
+
599
+ # Get the default value for this field from the config
600
+ default_value = config.get(f"default_{field}", None)
601
+
602
+ # For fields that have min/max bounds
603
+ min_field = f"min_{field}"
604
+ max_field = f"max_{field}"
605
+
606
+ # Check if min/max constraints exist for this field
607
+ has_constraints = min_field in config or max_field in config
608
+
609
+ if not has_constraints:
610
+ # For fields without constraints, just return the value from config
611
+ return default_value
612
+
613
+ # Get min and max values from config (if they exist)
614
+ min_value = config.get(min_field, None)
615
+ max_value = config.get(max_field, None)
616
+
617
+ # If user provided options with this field
618
+ if options and field in options:
619
+ user_value = options[field]
620
+
621
+ # Apply constraints if they exist
622
+ if min_value is not None and user_value < min_value:
623
+ return min_value
624
+ if max_value is not None and user_value > max_value:
625
+ return max_value
626
+
627
+ # If within bounds, use the user's value
628
+ return user_value
629
+
630
+ # If no user value, return the default
631
+ return default_value
632
+
633
+ async def _generate_clip_prompt(self, video_id: str, title: str, description: str) -> str:
634
+ """Generate a new prompt for the next clip based on event history"""
635
+ events = self.video_events.get(video_id, [])
636
+ events_json = "\n".join(json.dumps(event) for event in events)
637
+
638
+ prompt = f"""# Context and task
639
+ Please write the caption for a new clip.
640
+
641
+ # Instructions
642
+ 1. Consider the video context and recent events
643
+ 2. Create a natural progression from previous clips
644
+ 3. Take into account user suggestions (chat messages) into the scene
645
+ 4. Don't generate hateful, political, violent or sexual content
646
+ 5. Keep visual consistency with previous clips (in most cases you should repeat the same exact description of the location, characters etc but only change a few elements. If this is a webcam scenario, don't touch the camera orientation or focus)
647
+ 6. Return ONLY the caption text, no additional formatting or explanation
648
+ 7. Write in English, about 200 words.
649
+ 8. Keep the visual style consistant, but content as well (repeat the style, character, locations, appearance etc.. across scenes, when it makes sense).
650
+ 8. Your caption must describe visual elements of the scene in details, including: camera angle and focus, people's appearance, age, look, costumes, clothes, the location visual characteristics and geometry, lighting, action, objects, weather, textures, lighting.
651
+
652
+ # Examples
653
+ Here is a demo scenario, with fake data:
654
+ {{"time": "2024-11-29T13:36:15Z", "event": "new_stream_clip", "caption": "webcam view of a beautiful park, squirrels are playing in the lush grass, blablabla etc... (rest omitted for brevity)"}}
655
+ {{"time": "2024-11-29T13:36:20Z", "event": "new_chat_message", "username": "MonkeyLover89", "data": "hi"}}
656
+ {{"time": "2024-11-29T13:36:25Z", "event": "new_chat_message", "username": "MonkeyLover89", "data": "more squirrels plz"}}
657
+ {{"time": "2024-11-29T13:36:26Z", "event": "new_stream_clip", "caption": "webcam view of a beautiful park, a lot of squirrels are playing in the lush grass, blablabla etc... (rest omitted for brevity)"}}
658
+
659
+ # Real scenario and data
660
+
661
+ We are inside a video titled "{title}"
662
+ The video is described by: "{description}".
663
+ Here is a summary of the {len(events)} most recent events:
664
+ {events_json}
665
+
666
+ # Your response
667
+ Your caption:"""
668
+
669
+ try:
670
+ response = await asyncio.get_event_loop().run_in_executor(
671
+ None,
672
+ lambda: self.inference_client.text_generation(
673
+ prompt,
674
+ model=TEXT_MODEL,
675
+ max_new_tokens=200,
676
+ temperature=0.7
677
+ )
678
+ )
679
+
680
+ # Clean up the response
681
+ caption = response.strip()
682
+ if caption.lower().startswith("caption:"):
683
+ caption = caption[8:].strip()
684
+
685
+ return caption
686
+
687
+ except Exception as e:
688
+ logger.error(f"Error generating clip prompt: {str(e)}")
689
+ # Fallback to original description if prompt generation fails
690
+ return description
691
+
692
+ async def generate_video_thumbnail(self, title: str, description: str, video_prompt_prefix: str, options: dict, user_role: UserRole = 'anon') -> str:
693
+ """
694
+ Generate a short, low-resolution video thumbnail for search results and previews.
695
+ Optimized for quick generation and low resource usage.
696
+ """
697
+ video_id = options.get('video_id', str(uuid.uuid4()))
698
+ seed = options.get('seed', generate_seed())
699
+ request_id = str(uuid.uuid4())[:8] # Generate a short ID for logging
700
+
701
+ logger.info(f"[{request_id}] Starting video thumbnail generation for video_id: {video_id}")
702
+ logger.info(f"[{request_id}] Title: '{title}', User role: {user_role}")
703
+
704
+ # Create a more concise prompt for the thumbnail
705
+ clip_caption = f"{video_prompt_prefix} - {title.strip()}"
706
+
707
+ # Add the thumbnail generation to event history
708
+ self._add_event(video_id, {
709
+ "time": datetime.datetime.utcnow().isoformat() + "Z",
710
+ "event": "thumbnail_generation",
711
+ "caption": clip_caption,
712
+ "seed": seed,
713
+ "request_id": request_id
714
+ })
715
+
716
+ # Use a shorter prompt for thumbnails
717
+ prompt = f"{clip_caption}, {POSITIVE_PROMPT_SUFFIX}"
718
+ logger.info(f"[{request_id}] Using prompt: '{prompt}'")
719
+
720
+ # Specialized configuration for thumbnails - smaller size, single frame
721
+ width = 512 # Reduced size for thumbnails
722
+ height = 288 # 16:9 aspect ratio
723
+ num_frames = THUMBNAIL_FRAMES # Minimal frame count for a lightweight thumbnail
724
+ num_inference_steps = 4 # Fewer steps for faster generation
725
+ frame_rate = 25 # Standard frame rate
726
+
727
+ # Optionally override with options if specified
728
+ width = options.get('width', width)
729
+ height = options.get('height', height)
730
+ num_frames = options.get('num_frames', num_frames)
731
+ num_inference_steps = options.get('num_inference_steps', num_inference_steps)
732
+ frame_rate = options.get('frame_rate', frame_rate)
733
+
734
+ logger.info(f"[{request_id}] Configuration: width={width}, height={height}, frames={num_frames}, steps={num_inference_steps}, fps={frame_rate}")
735
+
736
+ # Add thumbnail-specific tag to help debugging and metrics
737
+ options['thumbnail'] = True
738
+
739
+ # Check for available endpoints before attempting generation
740
+ available_endpoints = sum(1 for ep in self.endpoint_manager.endpoints
741
+ if not ep.busy and time.time() > ep.error_until)
742
+ logger.info(f"[{request_id}] Available endpoints: {available_endpoints}/{len(self.endpoint_manager.endpoints)}")
743
+
744
+ if available_endpoints == 0:
745
+ logger.error(f"[{request_id}] No available endpoints for thumbnail generation")
746
+ return ""
747
+
748
+ # Use the same logic as regular video generation but with thumbnail settings
749
+ try:
750
+ # logger.info(f"[{request_id}] Generating thumbnail for video {video_id} with seed {seed}")
751
+
752
+ start_time = time.time()
753
+ # Rest of thumbnail generation logic same as regular video but with optimized settings
754
+ result = await self._generate_video_content(
755
+ prompt=prompt,
756
+ negative_prompt=options.get('negative_prompt', NEGATIVE_PROMPT),
757
+ width=width,
758
+ height=height,
759
+ num_frames=num_frames,
760
+ num_inference_steps=num_inference_steps,
761
+ frame_rate=frame_rate,
762
+ seed=seed,
763
+ options=options,
764
+ user_role=user_role
765
+ )
766
+ duration = time.time() - start_time
767
+
768
+ if result:
769
+ data_length = len(result)
770
+ logger.info(f"[{request_id}] Successfully generated thumbnail in {duration:.2f}s, data length: {data_length} chars")
771
+ return result
772
+ else:
773
+ logger.error(f"[{request_id}] Empty result returned from video generation")
774
+ return ""
775
+
776
+ except Exception as e:
777
+ logger.error(f"[{request_id}] Error generating thumbnail: {e}")
778
+ if hasattr(e, "__traceback__"):
779
+ import traceback
780
+ logger.error(f"[{request_id}] Traceback: {traceback.format_exc()}")
781
+ return "" # Return empty string instead of raising to avoid crashes
782
+
783
+ async def generate_video(self, title: str, description: str, video_prompt_prefix: str, options: dict, user_role: UserRole = 'anon') -> str:
784
+ """Generate video using available space from pool"""
785
+ video_id = options.get('video_id', str(uuid.uuid4()))
786
+
787
+ # Generate a new prompt based on event history
788
+ #clip_caption = await self._generate_clip_prompt(video_id, title, description)
789
+ clip_caption = f"{video_prompt_prefix} - {title.strip()} - {description.strip()}"
790
+
791
+ # Add the new clip to event history
792
+ self._add_event(video_id, {
793
+ "time": datetime.datetime.utcnow().isoformat() + "Z",
794
+ "event": "new_stream_clip",
795
+ "caption": clip_caption
796
+ })
797
+
798
+ # Use the generated caption as the prompt
799
+ prompt = f"{clip_caption}, {POSITIVE_PROMPT_SUFFIX}"
800
+
801
+ # Get the config values based on user role
802
+ width = self.get_config_value(user_role, 'clip_width', options)
803
+ height = self.get_config_value(user_role, 'clip_height', options)
804
+ num_frames = self.get_config_value(user_role, 'num_frames', options)
805
+ num_inference_steps = self.get_config_value(user_role, 'num_inference_steps', options)
806
+ frame_rate = self.get_config_value(user_role, 'clip_framerate', options)
807
+
808
+ # Get orientation from options
809
+ orientation = options.get('orientation', 'LANDSCAPE')
810
+
811
+ # Adjust width and height based on orientation if needed
812
+ if orientation == 'PORTRAIT' and width > height:
813
+ # Swap width and height for portrait orientation
814
+ width, height = height, width
815
+ # logger.info(f"Orientation: {orientation}, swapped dimensions to width={width}, height={height}")
816
+ elif orientation == 'LANDSCAPE' and height > width:
817
+ # Swap height and width for landscape orientation
818
+ height, width = width, height
819
+ # logger.info(f"generate_video() Orientation: {orientation}, swapped dimensions to width={width}, height={height}, steps={num_inference_steps}, fps={frame_rate} | role: {user_role}")
820
+ else:
821
+ # logger.info(f"generate_video() Orientation: {orientation}, using original dimensions width={width}, height={height}, steps={num_inference_steps}, fps={frame_rate} | role: {user_role}")
822
+ pass
823
+
824
+ # Generate the video with standard settings
825
+ return await self._generate_video_content(
826
+ prompt=prompt,
827
+ negative_prompt=options.get('negative_prompt', NEGATIVE_PROMPT),
828
+ width=width,
829
+ height=height,
830
+ num_frames=num_frames,
831
+ num_inference_steps=num_inference_steps,
832
+ frame_rate=frame_rate,
833
+ seed=options.get('seed', 42),
834
+ options=options,
835
+ user_role=user_role
836
+ )
837
+
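The orientation handling above only swaps dimensions when they disagree with the requested orientation, for example (dimension values hypothetical):

# width=768, height=416, orientation='PORTRAIT'  -> swapped to width=416, height=768
# width=768, height=416, orientation='LANDSCAPE' -> unchanged (already landscape)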
838
+ async def _generate_video_content(self, prompt: str, negative_prompt: str, width: int,
839
+ height: int, num_frames: int, num_inference_steps: int,
840
+ frame_rate: int, seed: int, options: dict, user_role: UserRole) -> str:
841
+ """
842
+ Internal method to generate video content with specific parameters.
843
+ Used by both regular video generation and thumbnail generation.
844
+ """
845
+ is_thumbnail = options.get('thumbnail', False)
846
+ request_id = options.get('request_id', str(uuid.uuid4())[:8]) # Get or generate request ID
847
+ video_id = options.get('video_id', 'unknown')
848
+
849
+ # logger.info(f"[{request_id}] Generating {'thumbnail' if is_thumbnail else 'video'} for video {video_id} with seed {seed}")
850
+
851
+ json_payload = {
852
+ "inputs": {
853
+ "prompt": prompt,
854
+ },
855
+ "parameters": {
856
+ # ------------------- settings for LTX-Video -----------------------
857
+ "negative_prompt": negative_prompt,
858
+ "width": width,
859
+ "height": height,
860
+ "num_frames": num_frames,
861
+ "num_inference_steps": num_inference_steps,
862
+ "guidance_scale": options.get('guidance_scale', GUIDANCE_SCALE),
863
+ "seed": seed,
864
+
865
+ # ------------------- settings for Varnish -----------------------
866
+ "double_num_frames": False, # <- False for real-time generation
867
+ "fps": frame_rate,
868
+ "super_resolution": False, # <- False for real-time generation
869
+ "grain_amount": 0, # No film grain (on low-res, low-quality generation the effects aren't worth it + it adds weight to the MP4 payload)
870
+ }
871
+ }
872
+
873
+ # Add thumbnail flag to help with metrics and debugging
874
+ if is_thumbnail:
875
+ json_payload["metadata"] = {
876
+ "is_thumbnail": True,
877
+ "thumbnail_version": "1.0",
878
+ "request_id": request_id
879
+ }
880
+
881
+ # logger.info(f"[{request_id}] Waiting for an available endpoint...")
882
+ async with self.endpoint_manager.get_endpoint() as endpoint:
883
+ # logger.info(f"[{request_id}] Using endpoint {endpoint.id} for generation")
884
+
885
+ try:
886
+ async with ClientSession() as session:
887
+ #logger.info(f"[{request_id}] Sending request to endpoint {endpoint.id}: {endpoint.url}")
888
+ start_time = time.time()
889
+
890
+ # Proceed with actual request
891
+ async with session.post(
892
+ endpoint.url,
893
+ headers={
894
+ "Accept": "application/json",
895
+ "Authorization": f"Bearer {HF_TOKEN}",
896
+ "Content-Type": "application/json",
897
+ "X-Request-ID": request_id # Add request ID to headers
898
+ },
899
+ json=json_payload,
900
+ timeout=12 # Extended timeout for thumbnails (was 8s)
901
+ ) as response:
902
+ request_duration = time.time() - start_time
903
+ #logger.info(f"[{request_id}] Received response from endpoint {endpoint.id} in {request_duration:.2f}s: HTTP {response.status}")
904
+
905
+ if response.status != 200:
906
+ error_text = await response.text()
907
+ logger.error(f"[{request_id}] Failed response: {error_text}")
908
+ # Mark endpoint as in error state
909
+ await self._mark_endpoint_error(endpoint)
910
+ if "paused" in error_text:
911
+ logger.error(f"[{request_id}] Endpoint is paused")
912
+ return ""
913
+ raise Exception(f"Video generation failed: HTTP {response.status} - {error_text}")
914
+
915
+ result = await response.json()
916
+ #logger.info(f"[{request_id}] Successfully parsed JSON response")
917
+
918
+ if "error" in result:
919
+ error_msg = result['error']
920
+ logger.error(f"[{request_id}] Error in response: {error_msg}")
921
+ # Mark endpoint as in error state
922
+ await self._mark_endpoint_error(endpoint)
923
+ if "paused" in str(error_msg).lower():
924
+ logger.error(f"[{request_id}] Endpoint is paused")
925
+ return ""
926
+ raise Exception(f"Video generation failed: {error_msg}")
927
+
928
+ video_data_uri = result.get("video")
929
+ if not video_data_uri:
930
+ logger.error(f"[{request_id}] No video data in response")
931
+ # Mark endpoint as in error state
932
+ await self._mark_endpoint_error(endpoint)
933
+ raise Exception("No video data in response")
934
+
935
+ # Get data size
936
+ data_size = len(video_data_uri)
937
+ #logger.info(f"[{request_id}] Received video data: {data_size} chars")
938
+
939
+ # Reset error count on successful call
940
+ endpoint.error_count = 0
941
+ endpoint.error_until = 0
942
+
943
+ return video_data_uri
944
+
945
+ except asyncio.TimeoutError:
946
+ # Handle timeout specifically
947
+ logger.error(f"[{request_id}] Timeout occurred after {time.time() - start_time:.2f}s")
948
+ await self._mark_endpoint_error(endpoint, is_timeout=True)
949
+ return ""
950
+ except Exception as e:
951
+ # Handle all other exceptions
952
+ logger.error(f"[{request_id}] Exception during video generation: {str(e)}")
953
+ # asyncio.TimeoutError is caught by the dedicated handler above, so this is a non-timeout error
954
+ await self._mark_endpoint_error(endpoint)
955
+ return ""
956
+
957
+ async def _mark_endpoint_error(self, endpoint: Endpoint, is_timeout: bool = False):
958
+ """Mark an endpoint as being in error state with exponential backoff"""
959
+ async with self.endpoint_manager.lock:
960
+ endpoint.error_count += 1
961
+
962
+ # Calculate backoff time exponentially based on error count
963
+ # Start with 15 seconds, then 30, 60, etc. up to a max of 5 minutes
964
+ # Using shorter backoffs since generation should be fast
965
+ backoff_seconds = min(15 * (2 ** (endpoint.error_count - 1)), 300)
966
+
967
+ # Add extra backoff for timeouts which are more indicative of serious issues
968
+ if is_timeout:
969
+ backoff_seconds *= 2
970
+
971
+ endpoint.error_until = time.time() + backoff_seconds
972
+
973
+ logger.warning(
974
+ f"Endpoint {endpoint.id} marked as in error state (count: {endpoint.error_count}, "
975
+ f"unavailable until: {datetime.datetime.fromtimestamp(endpoint.error_until).strftime('%H:%M:%S')})"
976
+ )
977
+
978
+
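The backoff schedule above works out as follows, straight from min(15 * 2**(n-1), 300) with the cap applied before the timeout doubling:

# error_count:       1    2    3     4     5     6+
# regular errors:   15s  30s  60s  120s  240s  300s
# timeouts (x2):    30s  60s 120s  240s  480s  600s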
979
+ async def handle_chat_message(self, data: dict, ws: web.WebSocketResponse) -> dict:
980
+ """Process and broadcast a chat message"""
981
+ video_id = data.get('videoId')
982
+ request_id = data.get('requestId')
983
+
984
+ if not video_id:
985
+ return {
986
+ 'action': 'chat_message',
987
+ 'requestId': request_id,
988
+ 'success': False,
989
+ 'error': 'No video ID provided'
990
+ }
991
+
992
+ # Add chat message to event history
993
+ self._add_event(video_id, {
994
+ "time": datetime.datetime.utcnow().isoformat() + "Z",
995
+ "event": "new_chat_message",
996
+ "username": data.get('username', 'Anonymous'),
997
+ "data": data.get('content', '')
998
+ })
999
+
1000
+ room = self.chat_rooms[video_id]
1001
+ message_data = {k: v for k, v in data.items() if k != '_ws'}
1002
+ room.add_message(message_data)
1003
+
1004
+ for client in list(room.connected_clients): # iterate over a copy so failed clients can be removed safely
1005
+ if client != ws:
1006
+ try:
1007
+ await client.send_json({
1008
+ 'action': 'chat_message',
1009
+ 'broadcast': True,
1010
+ **message_data
1011
+ })
1012
+ except Exception as e:
1013
+ logger.error(f"Failed to broadcast to client: {e}")
1014
+ room.connected_clients.remove(client)
1015
+
1016
+ return {
1017
+ 'action': 'chat_message',
1018
+ 'requestId': request_id,
1019
+ 'success': True,
1020
+ 'message': message_data
1021
+ }
1022
+
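For reference, a minimal client-side sketch of the chat protocol handled above (the ws:// URL and field values are illustrative):

import asyncio
import aiohttp

async def send_chat():
    async with aiohttp.ClientSession() as session:
        # URL is hypothetical; use the server's actual WebSocket endpoint
        async with session.ws_connect("ws://localhost:8080/ws") as ws:
            await ws.send_json({
                "action": "chat_message",
                "requestId": "req-1",
                "videoId": "abc123",
                "username": "demo",
                "content": "more squirrels plz",
            })
            reply = await ws.receive_json()
            assert reply["action"] == "chat_message" and reply["success"]

asyncio.run(send_chat())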
1023
+ async def handle_join_chat(self, data: dict, ws: web.WebSocketResponse) -> dict:
1024
+ """Handle a request to join a chat room"""
1025
+ video_id = data.get('videoId')
1026
+ request_id = data.get('requestId')
1027
+
1028
+ if not video_id:
1029
+ return {
1030
+ 'action': 'join_chat',
1031
+ 'requestId': request_id,
1032
+ 'success': False,
1033
+ 'error': 'No video ID provided'
1034
+ }
1035
+
1036
+ room = self.chat_rooms[video_id]
1037
+ room.connected_clients.add(ws)
1038
+ recent_messages = room.get_recent_messages()
1039
+
1040
+ return {
1041
+ 'action': 'join_chat',
1042
+ 'requestId': request_id,
1043
+ 'success': True,
1044
+ 'messages': recent_messages
1045
+ }
1046
+
1047
+ async def handle_leave_chat(self, data: dict, ws: web.WebSocketResponse) -> dict:
1048
+ """Handle a request to leave a chat room"""
1049
+ video_id = data.get('videoId')
1050
+ request_id = data.get('requestId')
1051
+
1052
+ if not video_id:
1053
+ return {
1054
+ 'action': 'leave_chat',
1055
+ 'requestId': request_id,
1056
+ 'success': False,
1057
+ 'error': 'No video ID provided'
1058
+ }
1059
+
1060
+ room = self.chat_rooms[video_id]
1061
+ if ws in room.connected_clients:
1062
+ room.connected_clients.remove(ws)
1063
+
1064
+ return {
1065
+ 'action': 'leave_chat',
1066
+ 'requestId': request_id,
1067
+ 'success': True
1068
+ }
reference_example/api_metrics.py ADDED
@@ -0,0 +1,185 @@
1
+ import time
2
+ import logging
3
+ import asyncio
4
+ from collections import defaultdict
5
+ from typing import Dict, List, Set, Optional
6
+ import datetime
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class MetricsTracker:
11
+ """
12
+ Tracks usage metrics across the API server.
13
+ """
14
+ def __init__(self):
15
+ # Total metrics since server start
16
+ self.total_requests = {
17
+ 'chat': 0,
18
+ 'video': 0,
19
+ 'search': 0,
20
+ 'other': 0,
21
+ }
22
+
23
+ # Per-user metrics
24
+ self.user_metrics = defaultdict(lambda: {
25
+ 'requests': {
26
+ 'chat': 0,
27
+ 'video': 0,
28
+ 'search': 0,
29
+ 'other': 0,
30
+ },
31
+ 'first_seen': time.time(),
32
+ 'last_active': time.time(),
33
+ 'role': 'anon'
34
+ })
35
+
36
+ # Rate limits (requests per minute) by user role
37
+ self.rate_limits = {
38
+ 'anon': {
39
+ 'video': 30,
40
+ 'search': 45,
41
+ 'chat': 90,
42
+ 'other': 45
43
+ },
44
+ 'normal': {
45
+ 'video': 60,
46
+ 'search': 90,
47
+ 'chat': 180,
48
+ 'other': 90
49
+ },
50
+ 'pro': {
51
+ 'video': 120,
52
+ 'search': 180,
53
+ 'chat': 300,
54
+ 'other': 180
55
+ },
56
+ 'admin': {
57
+ 'video': 240,
58
+ 'search': 360,
59
+ 'chat': 450,
60
+ 'other': 360
61
+ }
62
+ }
63
+
64
+ # Minute-based rate limiting buckets
65
+ self.time_buckets = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
66
+
67
+ # Lock for thread safety
68
+ self.lock = asyncio.Lock()
69
+
70
+ # Track concurrent sessions by IP
71
+ self.ip_sessions = defaultdict(set)
72
+
73
+ # Server start time
74
+ self.start_time = time.time()
75
+
76
+ async def record_request(self, user_id: str, ip: str, request_type: str, role: str):
77
+ """Record a request for metrics and rate limiting"""
78
+ async with self.lock:
79
+ # Update total metrics
80
+ if request_type in self.total_requests:
81
+ self.total_requests[request_type] += 1
82
+ else:
83
+ self.total_requests['other'] += 1
84
+
85
+ # Update user metrics
86
+ user_data = self.user_metrics[user_id]
87
+ user_data['last_active'] = time.time()
88
+ user_data['role'] = role
89
+
90
+ if request_type in user_data['requests']:
91
+ user_data['requests'][request_type] += 1
92
+ else:
93
+ user_data['requests']['other'] += 1
94
+
95
+ # Update time bucket for rate limiting
96
+ current_minute = int(time.time() / 60)
97
+ self.time_buckets[user_id][current_minute][request_type] += 1
98
+
99
+ # Clean up old time buckets (keep only last 10 minutes)
100
+ cutoff = current_minute - 10
101
+ for minute in list(self.time_buckets[user_id].keys()):
102
+ if minute < cutoff:
103
+ del self.time_buckets[user_id][minute]
104
+
105
+ def register_session(self, user_id: str, ip: str):
106
+ """Register a new session for an IP address"""
107
+ self.ip_sessions[ip].add(user_id)
108
+
109
+ def unregister_session(self, user_id: str, ip: str):
110
+ """Unregister a session when it disconnects"""
111
+ if user_id in self.ip_sessions[ip]:
112
+ self.ip_sessions[ip].remove(user_id)
113
+ if not self.ip_sessions[ip]:
114
+ del self.ip_sessions[ip]
115
+
116
+ def get_session_count_for_ip(self, ip: str) -> int:
117
+ """Get the number of active sessions for an IP address"""
118
+ return len(self.ip_sessions.get(ip, set()))
119
+
120
+ async def is_rate_limited(self, user_id: str, request_type: str, role: str) -> bool:
121
+ """Check if a user is currently rate limited for a request type"""
122
+ async with self.lock:
123
+ current_minute = int(time.time() / 60)
124
+ prev_minute = current_minute - 1
125
+
126
+ # Count requests in current and previous minute
127
+ current_count = self.time_buckets[user_id][current_minute][request_type]
128
+ prev_count = self.time_buckets[user_id][prev_minute][request_type]
129
+
130
+ # Calculate requests per minute rate (weighted average)
131
+ # Weight current minute more as it's more recent
132
+ rate = (current_count * 0.7) + (prev_count * 0.3)
133
+
134
+ # Get rate limit based on user role
135
+ limit = self.rate_limits.get(role, self.rate_limits['anon']).get(
136
+ request_type, self.rate_limits['anon']['other'])
137
+
138
+ # Check if rate exceeds limit
139
+ return rate >= limit
140
+
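A quick worked example of the weighted-average check above, for an anonymous user with 40 video requests in the current minute and 20 in the previous one:

# rate = 40 * 0.7 + 20 * 0.3 = 34.0
# anon 'video' limit = 30 -> 34.0 >= 30, so the request is rate limited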
141
+ def get_metrics(self) -> Dict:
142
+ """Get a snapshot of current metrics"""
143
+ active_users = {
144
+ 'total': len(self.user_metrics),
145
+ 'anon': 0,
146
+ 'normal': 0,
147
+ 'pro': 0,
148
+ 'admin': 0,
149
+ }
150
+
151
+ # Count active users in the last 5 minutes
152
+ active_cutoff = time.time() - (5 * 60)
153
+ for user_data in self.user_metrics.values():
154
+ if user_data['last_active'] >= active_cutoff:
155
+ active_users[user_data['role']] += 1
156
+
157
+ return {
158
+ 'uptime_seconds': int(time.time() - self.start_time),
159
+ 'total_requests': dict(self.total_requests),
160
+ 'active_users': active_users,
161
+ 'active_ips': len(self.ip_sessions),
162
+ 'timestamp': datetime.datetime.now().isoformat()
163
+ }
164
+
165
+ def get_detailed_metrics(self) -> Dict:
166
+ """Get detailed metrics including per-user data"""
167
+ metrics = self.get_metrics()
168
+
169
+ # Add anonymized user metrics
170
+ user_list = []
171
+ for user_id, data in self.user_metrics.items():
172
+ # Skip users inactive for more than 1 hour
173
+ if time.time() - data['last_active'] > 3600:
174
+ continue
175
+
176
+ user_list.append({
177
+ 'id': user_id[:8] + '...', # Anonymize ID
178
+ 'role': data['role'],
179
+ 'requests': data['requests'],
180
+ 'active_ago': int(time.time() - data['last_active']),
181
+ 'session_duration': int(time.time() - data['first_seen'])
182
+ })
183
+
184
+ metrics['users'] = user_list
185
+ return metrics
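For reference, the snapshot returned by get_metrics() has this shape (values illustrative):

# {
#   'uptime_seconds': 3600,
#   'total_requests': {'chat': 120, 'video': 45, 'search': 30, 'other': 5},
#   'active_users': {'total': 8, 'anon': 5, 'normal': 2, 'pro': 1, 'admin': 0},
#   'active_ips': 6,
#   'timestamp': '2024-11-29T13:36:15'
# }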
reference_example/api_session.py ADDED
@@ -0,0 +1,569 @@
1
+ import asyncio
2
+ import logging
3
+ from typing import Dict, Set
4
+ from aiohttp import web, WSMsgType
5
+ import json
6
+ import time
7
+ import datetime
8
+ from api_core import VideoGenerationAPI
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class UserSession:
13
+ """
14
+ Represents a user's session with the API.
15
+ Each WebSocket connection gets its own session with separate queues and rate limits.
16
+ """
17
+ def __init__(self, user_id: str, user_role: str, ws: web.WebSocketResponse, shared_api):
18
+ self.user_id = user_id
19
+ self.user_role = user_role
20
+ self.ws = ws
21
+ self.shared_api = shared_api # For shared resources like endpoint manager
22
+
23
+ # Create separate queues for this user session
24
+ self.chat_queue = asyncio.Queue()
25
+ self.video_queue = asyncio.Queue()
26
+ self.search_queue = asyncio.Queue()
27
+ self.simulation_queue = asyncio.Queue() # New queue for description evolution
28
+
29
+ # Track request counts and rate limits
30
+ self.request_counts = {
31
+ 'chat': 0,
32
+ 'video': 0,
33
+ 'search': 0,
34
+ 'simulation': 0 # New counter for simulation requests
35
+ }
36
+
37
+ # Last request timestamps for rate limiting
38
+ self.last_request_times = {
39
+ 'chat': time.time(),
40
+ 'video': time.time(),
41
+ 'search': time.time(),
42
+ 'simulation': time.time() # New timestamp for simulation requests
43
+ }
44
+
45
+ # Session creation time
46
+ self.created_at = time.time()
47
+
48
+ self.background_tasks = []
49
+
50
+ async def start(self):
51
+ """Start all the queue processors for this session"""
52
+ # Start background tasks for handling different request types
53
+ self.background_tasks = [
54
+ asyncio.create_task(self._process_chat_queue()),
55
+ asyncio.create_task(self._process_video_queue()),
56
+ asyncio.create_task(self._process_search_queue()),
57
+ asyncio.create_task(self._process_simulation_queue()) # New worker for simulation requests
58
+ ]
59
+ logger.info(f"Started session for user {self.user_id} with role {self.user_role}")
60
+
61
+ async def stop(self):
62
+ """Stop all background tasks for this session"""
63
+ for task in self.background_tasks:
64
+ task.cancel()
65
+
66
+ try:
67
+ # Wait for tasks to complete cancellation
68
+ await asyncio.gather(*self.background_tasks, return_exceptions=True)
69
+ except asyncio.CancelledError:
70
+ pass
71
+
72
+ logger.info(f"Stopped session for user {self.user_id}")
73
+
74
+ async def _process_chat_queue(self):
75
+ """High priority queue for chat operations"""
76
+ while True:
77
+ data = await self.chat_queue.get()
78
+ try:
79
+ if data['action'] == 'join_chat':
80
+ result = await self.shared_api.handle_join_chat(data, self.ws)
81
+ elif data['action'] == 'chat_message':
82
+ result = await self.shared_api.handle_chat_message(data, self.ws)
83
+ elif data['action'] == 'leave_chat':
84
+ result = await self.shared_api.handle_leave_chat(data, self.ws)
85
+ # Redirect thumbnail requests to process_generic_request for consistent handling
86
+ elif data['action'] == 'generate_video_thumbnail':
87
+ # Pass to the generic request handler to maintain consistent logic
88
+ await self.process_generic_request(data)
89
+ # Skip normal response handling since process_generic_request already sends a response
90
+ self.chat_queue.task_done()
91
+ continue
92
+ else:
93
+ raise ValueError(f"Unknown chat action: {data['action']}")
94
+
95
+ await self.ws.send_json(result)
96
+
97
+ # Update metrics
98
+ self.request_counts['chat'] += 1
99
+ self.last_request_times['chat'] = time.time()
100
+
101
+ except Exception as e:
102
+ logger.error(f"Error processing chat request for user {self.user_id}: {e}")
103
+ try:
104
+ await self.ws.send_json({
105
+ 'action': data['action'],
106
+ 'requestId': data.get('requestId'),
107
+ 'success': False,
108
+ 'error': f'Chat error: {str(e)}'
109
+ })
110
+ except Exception as send_error:
111
+ logger.error(f"Error sending error response: {send_error}")
112
+ finally:
113
+ self.chat_queue.task_done()
114
+
115
+ async def _process_video_queue(self):
116
+ """Process multiple video generation requests in parallel for this user"""
117
+ from api_config import VIDEO_ROUND_ROBIN_ENDPOINT_URLS
118
+
119
+ active_tasks = set()
120
+ # Set a per-user concurrent limit based on role
121
+ max_concurrent = len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS)
122
+ if self.user_role == 'anon':
123
+ max_concurrent = min(2, max_concurrent) # Limit anonymous users
124
+ elif self.user_role == 'normal':
125
+ max_concurrent = min(4, max_concurrent) # Standard users
126
+ # Pro and admin can use all endpoints
127
+
128
+ async def process_single_request(data):
129
+ try:
130
+ title = data.get('title', '')
131
+ description = data.get('description', '')
132
+ video_prompt_prefix = data.get('video_prompt_prefix', '')
133
+ options = data.get('options', {})
134
+
135
+ # Pass the user role to generate_video
136
+ video_data = await self.shared_api.generate_video(
137
+ title, description, video_prompt_prefix, options, self.user_role
138
+ )
139
+
140
+ result = {
141
+ 'action': 'generate_video',
142
+ 'requestId': data.get('requestId'),
143
+ 'success': True,
144
+ 'video': video_data,
145
+ }
146
+
147
+ await self.ws.send_json(result)
148
+
149
+ # Update metrics
150
+ self.request_counts['video'] += 1
151
+ self.last_request_times['video'] = time.time()
152
+
153
+ except Exception as e:
154
+ logger.error(f"Error processing video request for user {self.user_id}: {e}")
155
+ try:
156
+ await self.ws.send_json({
157
+ 'action': 'generate_video',
158
+ 'requestId': data.get('requestId'),
159
+ 'success': False,
160
+ 'error': f'Video generation error: {str(e)}'
161
+ })
162
+ except Exception as send_error:
163
+ logger.error(f"Error sending error response: {send_error}")
164
+ finally:
165
+ active_tasks.discard(asyncio.current_task())
166
+
167
+ while True:
168
+ # Clean up completed tasks
169
+ active_tasks = {task for task in active_tasks if not task.done()}
170
+
171
+ # Start new tasks if we have capacity
172
+ while len(active_tasks) < max_concurrent:
173
+ try:
174
+ # Use wait_for with a short timeout to avoid blocking if the queue is empty
175
+ data = await asyncio.wait_for(self.video_queue.get(), timeout=0.1)
176
+
177
+ # Create and start new task
178
+ task = asyncio.create_task(process_single_request(data))
179
+ active_tasks.add(task)
180
+
181
+ except asyncio.TimeoutError:
182
+ # No items in queue, break inner loop
183
+ break
184
+ except Exception as e:
185
+ logger.error(f"Error creating video generation task for user {self.user_id}: {e}")
186
+ break
187
+
188
+ # Wait a short time before checking queue again
189
+ await asyncio.sleep(0.1)
190
+
191
+ # Handle any completed tasks' errors
192
+ for task in list(active_tasks):
193
+ if task.done():
194
+ try:
195
+ await task
196
+ except Exception as e:
197
+ logger.error(f"Task failed with error for user {self.user_id}: {e}")
198
+ active_tasks.discard(task)
199
+
200
+ async def _process_search_queue(self):
201
+ """Medium priority queue for search operations"""
202
+ while True:
203
+ try:
204
+ data = await self.search_queue.get()
205
+ request_id = data.get('requestId')
206
+ query = data.get('query', '').strip()
207
+ attempt_count = data.get('attemptCount', 0)
208
+
209
+ # logger.info(f"Processing search request for user {self.user_id}, attempt={attempt_count}")
210
+
211
+ if not query:
212
+ logger.warning(f"Empty query received in request from user {self.user_id}: {data}")
213
+ result = {
214
+ 'action': 'search',
215
+ 'requestId': request_id,
216
+ 'success': False,
217
+ 'error': 'No search query provided'
218
+ }
219
+ else:
220
+ try:
221
+ search_result = await self.shared_api.search_video(
222
+ query,
223
+ attempt_count=attempt_count
224
+ )
225
+
226
+ if search_result:
227
+ # logger.info(f"Search successful for user {self.user_id}, query '{query}'")
228
+ result = {
229
+ 'action': 'search',
230
+ 'requestId': request_id,
231
+ 'success': True,
232
+ 'result': search_result
233
+ }
234
+ else:
235
+ # logger.warning(f"No results found for user {self.user_id}, query '{query}'")
236
+ result = {
237
+ 'action': 'search',
238
+ 'requestId': request_id,
239
+ 'success': False,
240
+ 'error': 'No results found'
241
+ }
242
+ except Exception as e:
243
+ logger.error(f"Search error for user {self.user_id}, (attempt {attempt_count}): {str(e)}")
244
+ result = {
245
+ 'action': 'search',
246
+ 'requestId': request_id,
247
+ 'success': False,
248
+ 'error': f'Search error: {str(e)}'
249
+ }
250
+
251
+ await self.ws.send_json(result)
252
+
253
+ # Update metrics
254
+ self.request_counts['search'] += 1
255
+ self.last_request_times['search'] = time.time()
256
+
257
+ except Exception as e:
258
+ logger.error(f"Error in search queue processor for user {self.user_id}: {str(e)}")
259
+ try:
260
+ error_response = {
261
+ 'action': 'search',
262
+ 'requestId': data.get('requestId') if 'data' in locals() else None,
263
+ 'success': False,
264
+ 'error': f'Internal server error: {str(e)}'
265
+ }
266
+ await self.ws.send_json(error_response)
267
+ except Exception as send_error:
268
+ logger.error(f"Error sending error response: {send_error}")
269
+ finally:
270
+ if 'search_queue' in self.__dict__:
271
+ self.search_queue.task_done()
272
+
273
+ async def _process_simulation_queue(self):
274
+ """Dedicated queue for video simulation requests"""
275
+ while True:
276
+ try:
277
+ data = await self.simulation_queue.get()
278
+ request_id = data.get('requestId')
279
+
280
+ # Extract parameters from the request
281
+ video_id = data.get('video_id', '')
282
+ original_title = data.get('original_title', '')
283
+ original_description = data.get('original_description', '')
284
+ current_description = data.get('current_description', '')
285
+ condensed_history = data.get('condensed_history', '')
286
+ evolution_count = data.get('evolution_count', 0)
287
+ chat_messages = data.get('chat_messages', '')
288
+
289
+ logger.info(f"Processing video simulation for user {self.user_id}, video_id={video_id}, evolution_count={evolution_count}")
290
+
291
+ # Validate required parameters
292
+ if not original_title or not original_description or not current_description:
293
+ result = {
294
+ 'action': 'simulate',
295
+ 'requestId': request_id,
296
+ 'success': False,
297
+ 'error': 'Missing required parameters'
298
+ }
299
+ else:
300
+ try:
301
+ # Call the simulate method in the API
302
+ simulation_result = await self.shared_api.simulate(
303
+ original_title=original_title,
304
+ original_description=original_description,
305
+ current_description=current_description,
306
+ condensed_history=condensed_history,
307
+ evolution_count=evolution_count,
308
+ chat_messages=chat_messages
309
+ )
310
+
311
+ result = {
312
+ 'action': 'simulate',
313
+ 'requestId': request_id,
314
+ 'success': True,
315
+ 'evolved_description': simulation_result['evolved_description'],
316
+ 'condensed_history': simulation_result['condensed_history']
317
+ }
318
+ except Exception as e:
319
+ logger.error(f"Error simulating video for user {self.user_id}, video_id={video_id}: {str(e)}")
320
+ result = {
321
+ 'action': 'simulate',
322
+ 'requestId': request_id,
323
+ 'success': False,
324
+ 'error': f'Simulation error: {str(e)}'
325
+ }
326
+
327
+ await self.ws.send_json(result)
328
+
329
+ # Update metrics
330
+ self.request_counts['simulation'] += 1
331
+ self.last_request_times['simulation'] = time.time()
332
+
333
+ except Exception as e:
334
+ logger.error(f"Error in simulation queue processor for user {self.user_id}: {str(e)}")
335
+ try:
336
+ error_response = {
337
+ 'action': 'simulate',
338
+ 'requestId': data.get('requestId') if 'data' in locals() else None,
339
+ 'success': False,
340
+ 'error': f'Internal server error: {str(e)}'
341
+ }
342
+ await self.ws.send_json(error_response)
343
+ except Exception as send_error:
344
+ logger.error(f"Error sending error response: {send_error}")
345
+ finally:
346
+ if 'simulation_queue' in self.__dict__:
347
+ self.simulation_queue.task_done()
348
+
349
+ async def process_generic_request(self, data: dict) -> None:
350
+ """Handle general requests that don't fit into specialized queues"""
351
+ try:
352
+ request_id = data.get('requestId')
353
+ action = data.get('action')
354
+
355
+ def error_response(message: str):
356
+ return {
357
+ 'action': action,
358
+ 'requestId': request_id,
359
+ 'success': False,
360
+ 'error': message
361
+ }
362
+
363
+ if action == 'heartbeat':
364
+ # Include user role info in heartbeat response
365
+ await self.ws.send_json({
366
+ 'action': 'heartbeat',
367
+ 'requestId': request_id,
368
+ 'success': True,
369
+ 'user_role': self.user_role
370
+ })
371
+
372
+ elif action == 'get_user_role':
373
+ # Return the user role information
374
+ await self.ws.send_json({
375
+ 'action': 'get_user_role',
376
+ 'requestId': request_id,
377
+ 'success': True,
378
+ 'user_role': self.user_role
379
+ })
380
+
381
+ elif action == 'generate_caption':
382
+ title = data.get('params', {}).get('title')
383
+ description = data.get('params', {}).get('description')
384
+
385
+ if not title or not description:
386
+ await self.ws.send_json(error_response('Missing title or description'))
387
+ return
388
+
389
+ caption = await self.shared_api.generate_caption(title, description)
390
+ await self.ws.send_json({
391
+ 'action': action,
392
+ 'requestId': request_id,
393
+ 'success': True,
394
+ 'caption': caption
395
+ })
396
+
397
+ # evolve_description is now handled by the dedicated simulation queue processor
398
+
399
+ elif action == 'generate_video_thumbnail':
400
+ title = data.get('title', '') or data.get('params', {}).get('title', '')
401
+ description = data.get('description', '') or data.get('params', {}).get('description', '')
402
+ video_prompt_prefix = data.get('video_prompt_prefix', '') or data.get('params', {}).get('video_prompt_prefix', '')
403
+ options = data.get('options', {}) or data.get('params', {}).get('options', {})
404
+
405
+ if not title:
406
+ await self.ws.send_json(error_response('Missing title for thumbnail generation'))
407
+ return
408
+
409
+ # Ensure the options include the thumbnail flag
410
+ options['thumbnail'] = True
411
+
412
+ # Prioritize thumbnail generation with higher priority
413
+ options['priority'] = 'high'
414
+
415
+ # Add small size settings if not already specified
416
+ if 'width' not in options:
417
+ options['width'] = 512 # Default thumbnail width
418
+ if 'height' not in options:
419
+ options['height'] = 288 # Default 16:9 aspect ratio
420
+ if 'num_frames' not in options:
421
+ options['num_frames'] = 25 # 1 second @ 25fps
422
+
423
+ # Let the API know this is a thumbnail for a specific video
424
+ options['video_id'] = data.get('video_id', f"thumbnail-{request_id}")
425
+
426
+ logger.info(f"Generating thumbnail for video {options['video_id']} for user {self.user_id}")
427
+
428
+ try:
429
+ # Generate the thumbnail
430
+ thumbnail_data = await self.shared_api.generate_video_thumbnail(
431
+ title, description, video_prompt_prefix, options, self.user_role
432
+ )
433
+
434
+ # Respond with appropriate format based on the parameter names used in the request
435
+ if 'thumbnailUrl' in data or 'thumbnailUrl' in data.get('params', {}):
436
+ # Legacy format using thumbnailUrl
437
+ await self.ws.send_json({
438
+ 'action': action,
439
+ 'requestId': request_id,
440
+ 'success': True,
441
+ 'thumbnailUrl': thumbnail_data or "",
442
+ })
443
+ else:
444
+ # New format using thumbnail
445
+ await self.ws.send_json({
446
+ 'action': action,
447
+ 'requestId': request_id,
448
+ 'success': True,
449
+ 'thumbnail': thumbnail_data,
450
+ })
451
+ except Exception as e:
452
+ logger.error(f"Error generating thumbnail: {str(e)}")
453
+ await self.ws.send_json(error_response(f"Thumbnail generation failed: {str(e)}"))
454
+
455
+ # Handle deprecated thumbnail actions
456
+ elif action == 'generate_thumbnail' or action == 'old_generate_thumbnail':
457
+ # Redirect to video thumbnail generation
458
+ logger.warning(f"Deprecated thumbnail action '{action}' used, redirecting to generate_video_thumbnail")
459
+
460
+ # Extract parameters
461
+ title = data.get('title', '') or data.get('params', {}).get('title', '')
462
+ description = data.get('description', '') or data.get('params', {}).get('description', '')
463
+
464
+ if not title or not description:
465
+ await self.ws.send_json(error_response('Missing title or description'))
466
+ return
467
+
468
+ # Create a new request with the correct action
469
+ new_request = {
470
+ 'action': 'generate_video_thumbnail',
471
+ 'requestId': request_id,
472
+ 'title': title,
473
+ 'description': description,
474
+ 'options': {
475
+ 'width': 512,
476
+ 'height': 288,
477
+ 'thumbnail': True,
478
+ 'video_id': f"thumbnail-{request_id}"
479
+ }
480
+ }
481
+
482
+ # Process with the new action
483
+ await self.process_generic_request(new_request)
484
+
485
+ else:
486
+ await self.ws.send_json(error_response(f'Unknown action: {action}'))
487
+
488
+ except Exception as e:
489
+ logger.error(f"Error processing generic request for user {self.user_id}: {str(e)}")
490
+ try:
491
+ await self.ws.send_json({
492
+ 'action': data.get('action'),
493
+ 'requestId': data.get('requestId'),
494
+ 'success': False,
495
+ 'error': f'Internal server error: {str(e)}'
496
+ })
497
+ except Exception as send_error:
498
+ logger.error(f"Error sending error response: {send_error}")
499
+
500
+ class SessionManager:
501
+ """
502
+ Manages all active user sessions and shared resources.
503
+ """
504
+ def __init__(self):
505
+ self.sessions = {}
506
+ self.shared_api = VideoGenerationAPI() # Single instance for shared resources
507
+ self.session_lock = asyncio.Lock()
508
+
509
+ async def create_session(self, user_id: str, user_role: str, ws: web.WebSocketResponse) -> UserSession:
510
+ """Create a new user session"""
511
+ async with self.session_lock:
512
+ # Create a new session for this user
513
+ session = UserSession(user_id, user_role, ws, self.shared_api)
514
+ await session.start()
515
+ self.sessions[user_id] = session
516
+ return session
517
+
518
+ async def delete_session(self, user_id: str) -> None:
519
+ """Delete a user session and clean up resources"""
520
+ async with self.session_lock:
521
+ if user_id in self.sessions:
522
+ session = self.sessions[user_id]
523
+ await session.stop()
524
+ del self.sessions[user_id]
525
+ logger.info(f"Deleted session for user {user_id}")
526
+
527
+ def get_session(self, user_id: str) -> UserSession:
528
+ """Get a user session if it exists"""
529
+ return self.sessions.get(user_id)
530
+
531
+ async def close_all_sessions(self) -> None:
532
+ """Close all active sessions (used during shutdown)"""
533
+ async with self.session_lock:
534
+ for user_id, session in list(self.sessions.items()):
535
+ await session.stop()
536
+ self.sessions.clear()
537
+ logger.info("Closed all active sessions")
538
+
539
+ @property
540
+ def session_count(self) -> int:
541
+ """Get the number of active sessions"""
542
+ return len(self.sessions)
543
+
544
+ def get_session_stats(self) -> Dict:
545
+ """Get statistics about active sessions"""
546
+ stats = {
547
+ 'total_sessions': len(self.sessions),
548
+ 'by_role': {
549
+ 'anon': 0,
550
+ 'normal': 0,
551
+ 'pro': 0,
552
+ 'admin': 0
553
+ },
554
+ 'requests': {
555
+ 'chat': 0,
556
+ 'video': 0,
557
+ 'search': 0,
558
+ 'simulation': 0
559
+ }
560
+ }
561
+
562
+ for session in self.sessions.values():
563
+ stats['by_role'][session.user_role] += 1
564
+ stats['requests']['chat'] += session.request_counts['chat']
565
+ stats['requests']['video'] += session.request_counts['video']
566
+ stats['requests']['search'] += session.request_counts['search']
567
+ stats['requests']['simulation'] += session.request_counts['simulation']
568
+
569
+ return stats
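A minimal sketch of how SessionManager is meant to be wired into an aiohttp WebSocket handler (the handler itself is illustrative; for brevity it routes every message through process_generic_request, whereas the real server also dispatches to the per-type queues):

import json
import uuid

session_manager = SessionManager()

async def websocket_handler(request):
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    user_id = str(uuid.uuid4())
    session = await session_manager.create_session(user_id, 'anon', ws)
    try:
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                await session.process_generic_request(json.loads(msg.data))
    finally:
        await session_manager.delete_session(user_id)
    return ws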
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ gymnasium==0.29.1
2
+ ale-py==0.9.0
3
+ h5py==3.11.0
4
+ huggingface-hub==0.17.2
5
+ hydra-core==1.3
6
+ numpy==1.26.0
7
+ opencv-python==4.10.0.84
8
+ pillow==10.3.0
9
+ pygame==2.5.2
10
+ torch==2.1.0
11
+ torchvision==0.16.0
12
+ torcheval==0.0.7
13
+ tqdm==4.66.4
14
+ wandb==0.17.0
scripts/import_run.py ADDED
@@ -0,0 +1,123 @@
1
+ #! /usr/bin/env python
2
+
3
+ import argparse
4
+ from functools import partial
5
+ import json
6
+ from pathlib import Path
7
+ import subprocess
8
+ from typing import Optional
9
+
10
+
11
+ def main() -> None:
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("host", type=str)
14
+ parser.add_argument("-v", "--verbose", action="store_true")
15
+ parser.add_argument("--user", type=Optional[str])
16
+ parser.add_argument("--rootdir", type=Optional[str])
17
+ args = parser.parse_args()
18
+
19
+ run = partial(subprocess.run, shell=True, check=True, text=True)
20
+ host = args.host if args.user is None else f"{args.user}@{args.host}"
21
+
22
+ def run_remote_cmd(cmd):
23
+ return subprocess.check_output(f"ssh {host} {cmd}", shell=True, text=True)
24
+
25
+ def ls(p):
26
+ out = run_remote_cmd(f"ls {p}")
27
+ return out.strip().split("\n")[::-1]
28
+
29
+ def ask(l, info=None):
30
+ print(
31
+ "\n".join(
32
+ [
33
+ f"{i:{len(str(len(l)))}d}: {d}"
34
+ + (f" ({info[d]})" if info is not None else "")
35
+ for i, d in enumerate(l, 1)
36
+ ]
37
+ )
38
+ )
39
+ while True:
40
+ i = input("\nEnter a number: ")
41
+ if i.isdigit() and 1 <= int(i) <= len(l):
42
+ break
43
+ print("\n/!\\ Invalid choice\n")
44
+ return l[int(i) - 1]
45
+
46
+ def ask_if_verbose(question, default):
47
+ if not args.verbose:
48
+ return default
49
+ suffix = "[Y|n]" if default else "[y|N]"
50
+ answer = input(f"{question} {suffix} ").lower()
51
+
52
+ return (answer != "n") if default else (answer == "y")
53
+
54
+ def get_info(rundir):
55
+ return json.loads(
56
+ run_remote_cmd(f"cat {rundir}/checkpoints/info_for_import_script.json")
57
+ )
58
+
59
+ if args.rootdir is None:
60
+ for p in Path(__file__).resolve().parents:
61
+ if (p / ".git").is_dir():
62
+ break
63
+ else:
64
+ raise RuntimeError("This file is not in a git repository")
65
+ out = run_remote_cmd(f"find -type d -name {p.name}").strip().split("\n")
66
+ assert len(out) == 1
67
+ rootdir = out[0]
68
+ else:
69
+ rootdir = f'{args.rootdir.strip().strip("/")}'
70
+
71
+ dates = ls(f"{rootdir}/outputs")
72
+ date = ask(dates)
73
+ times = ls(f"{rootdir}/outputs/{date}")
74
+
75
+ infos = {
76
+ time: get_info(rundir=f"{rootdir}/outputs/{date}/{time}") for time in times
77
+ }
78
+ time = ask(times, infos)
79
+
80
+ src = f"{rootdir}/outputs/{date}/{time}"
81
+
82
+ dst = Path(args.host) / date
83
+ dst.mkdir(exist_ok=True, parents=True)
84
+
85
+ exclude = [
86
+ "*.log",
87
+ "checkpoints/*",
88
+ "checkpoints_tmp",
89
+ ".hydra",
90
+ "media",
91
+ "__pycache__",
92
+ "wandb",
93
+ ]
94
+
95
+ include = ["checkpoints/agent_versions"]
96
+
97
+ if ask_if_verbose("Download only last checkpoint?", default=True):
98
+ last_ckpt = ls(f"{src}/checkpoints/agent_versions")[0]
99
+ exclude.append("checkpoints/agent_versions/*")
100
+ include.append(f"checkpoints/agent_versions/{last_ckpt}")
101
+
102
+ if not ask_if_verbose("Download train dataset?", default=False):
103
+ exclude.append("dataset/train")
104
+
105
+ if not ask_if_verbose("Download test dataset?", default=False):
106
+ exclude.append("dataset/test")
107
+
108
+ cmd = "rsync -av"
109
+ for i in include:
110
+ cmd += f' --include="{i}"'
111
+ for e in exclude:
112
+ cmd += f' --exclude="{e}"'
113
+
114
+ cmd += f" {host}:{src} {str(dst)}"
115
+ run(cmd)
116
+
117
+ path = (dst / time).absolute()
118
+ print(f"\n--> Run imported in:\n{path}")
119
+ run(f"echo {path} | xclip")
120
+
121
+
122
+ if __name__ == "__main__":
123
+ main()
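For reference, the rsync command assembled above ends up looking roughly like this (host, checkpoint name, and paths hypothetical):

# rsync -av --include="checkpoints/agent_versions" \
#       --include="checkpoints/agent_versions/<last_ckpt>" \
#       --exclude="*.log" --exclude="checkpoints/*" --exclude="checkpoints_tmp" \
#       --exclude=".hydra" --exclude="media" --exclude="__pycache__" --exclude="wandb" \
#       --exclude="checkpoints/agent_versions/*" \
#       user@host:outputs/2024-05-01/12-00-00 2024-05-01/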
scripts/resume.sh ADDED
@@ -0,0 +1 @@
1
+ python src/main.py common.resume=True hydra.output_subdir=null hydra.run.dir=.
server.py ADDED
@@ -0,0 +1,796 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """
5
+ AI Game Multiverse Cloud Gaming Server
6
+
7
+ This script implements a websocket server for the AI Game Multiverse project,
8
+ allowing real-time streaming of game frames based on player inputs.
9
+ """
10
+
11
+ import asyncio
12
+ import json
13
+ import logging
14
+ import os
15
+ import pathlib
16
+ import time
17
+ import uuid
18
+ import base64
19
+ import argparse
20
+ from typing import Dict, List, Any, Optional
21
+ from aiohttp import web, WSMsgType
22
+
23
+ # Configure logging
24
+ logging.basicConfig(
25
+ level=logging.INFO,
26
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
27
+ )
28
+ logger = logging.getLogger(__name__)
29
+
30
+ class SimpleGameEngine:
31
+ """
32
+ A simple game engine that generates placeholder frames.
33
+ This is used when the main model engine is not available.
34
+ """
35
+ def __init__(self, args=None):
36
+ self.frame_width = getattr(args, 'frame_width', 640)
37
+ self.frame_height = getattr(args, 'frame_height', 360)
38
+ self.frame_count = 0
39
+
40
+ # Create placeholder scenes
41
+ self.scenes = {}
42
+ self._create_placeholder_scenes()
43
+
44
+ def _create_placeholder_scenes(self):
45
+ """Create placeholder scene frames for demo purposes"""
46
+ scene_names = ['forest', 'desert', 'beach', 'hills', 'river', 'plain']
47
+
48
+ for scene_name in scene_names:
49
+ frames = []
50
+ import numpy as np  # hoisted above the loop; kept inside the method so the import stays lazy
51
+ import cv2
52
+ for i in range(5): # Create 5 frames per scene
53
+
54
+ # Create a colored frame based on scene name
55
+ if scene_name == 'forest':
56
+ color = (34, 139, 34) # Forest green
57
+ elif scene_name == 'desert':
58
+ color = (210, 180, 140) # Desert sand
59
+ elif scene_name == 'beach':
60
+ color = (238, 214, 175) # Beach sand
61
+ elif scene_name == 'hills':
62
+ color = (85, 107, 47) # Olive green
63
+ elif scene_name == 'river':
64
+ color = (65, 105, 225) # Royal blue
65
+ else:
66
+ color = (160, 160, 160) # Gray
67
+
68
+ # Create base frame
69
+ frame = np.ones((self.frame_height, self.frame_width, 3), dtype=np.uint8)
70
+ frame[:] = color
71
+
72
+ # Add scene name and frame number
73
+ cv2.putText(frame, f"Scene: {scene_name}", (50, 180),
74
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
75
+ cv2.putText(frame, f"Frame {i}", (50, 220),
76
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
77
+
78
+ frames.append(frame)
79
+
80
+ self.scenes[scene_name] = frames
81
+
82
+ def get_valid_scenes(self) -> List[str]:
83
+ """
84
+ Get a list of valid scene names.
85
+
86
+ Returns:
87
+ List[str]: List of valid scene names
88
+ """
89
+ return list(self.scenes.keys())
90
+
91
+ def generate_frame(self, scene_name: str, keyboard_condition: Optional[List] = None,
92
+ mouse_condition: Optional[List] = None) -> bytes:
93
+ """
94
+ Generate a simple frame based on the scene and input conditions.
95
+
96
+ Args:
97
+ scene_name: Name of the current scene
98
+ keyboard_condition: Keyboard input state
99
+ mouse_condition: Mouse input state
100
+
101
+ Returns:
102
+ bytes: JPEG bytes of the frame
103
+ """
104
+ import numpy as np
105
+ import cv2
106
+
107
+ scene_frames = self.scenes.get(scene_name, self.scenes['forest'])
108
+ frame_idx = self.frame_count % len(scene_frames)
109
+ frame = scene_frames[frame_idx].copy()
110
+ self.frame_count += 1
111
+
112
+ # Add visualization of input controls
113
+ frame = self._visualize_controls(frame, keyboard_condition, mouse_condition)
114
+
115
+ # Convert frame to JPEG
116
+ success, buffer = cv2.imencode('.jpg', frame)
117
+ if not success:
118
+ # Return a blank frame
119
+ blank = np.ones((self.frame_height, self.frame_width, 3), dtype=np.uint8) * 100
120
+ success, buffer = cv2.imencode('.jpg', blank)
121
+
122
+ return buffer.tobytes()
123
+
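A short usage sketch for the placeholder engine above (output filename illustrative):

engine = SimpleGameEngine()
jpeg_bytes = engine.generate_frame(
    "forest",
    keyboard_condition=[[1, 0, 0, 0, 0, 0]],  # W / forward pressed
    mouse_condition=[[0.25, -0.1]],           # normalized mouse x, y
)
with open("frame.jpg", "wb") as f:
    f.write(jpeg_bytes)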
124
+ def _visualize_controls(self, frame: "np.ndarray", keyboard_condition: List, mouse_condition: List) -> "np.ndarray":  # annotations quoted because numpy is imported lazily inside methods
125
+ """Visualize keyboard and mouse controls on the frame."""
126
+ import cv2
127
+
128
+ # Clone the frame to avoid modifying the original
129
+ frame = frame.copy()
130
+
131
+ # If we have keyboard/mouse conditions, visualize them on the frame
132
+ if keyboard_condition:
133
+ # Visualize keyboard inputs
134
+ keys = ["W", "S", "A", "D", "JUMP", "ATTACK"]
135
+ for i, key_pressed in enumerate(keyboard_condition[0]):
136
+ color = (0, 255, 0) if key_pressed else (100, 100, 100)
137
+ cv2.putText(frame, keys[i], (20 + i*100, 30),
138
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
139
+
140
+ if mouse_condition:
141
+ # Visualize mouse movement
142
+ mouse_x, mouse_y = mouse_condition[0]
143
+ # Scale mouse values for visualization
144
+ offset_x = int(mouse_x * 100)
145
+ offset_y = int(mouse_y * 100)
146
+ center_x, center_y = self.frame_width // 2, self.frame_height // 2
147
+ cv2.circle(frame, (center_x + offset_x, center_y - offset_y), 10, (255, 0, 0), -1)
148
+ cv2.putText(frame, f"Mouse: {mouse_x:.2f}, {mouse_y:.2f}",
149
+ (self.frame_width - 250, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
150
+
151
+ return frame
152
+
153
+ class GameSession:
154
+ """
155
+ Represents a user's gaming session.
156
+ Each WebSocket connection gets its own session with separate queues.
157
+ """
158
+ def __init__(self, user_id: str, ws: web.WebSocketResponse, game_manager):
159
+ self.user_id = user_id
160
+ self.ws = ws
161
+ self.game_manager = game_manager
162
+
163
+ # Create action queue for this user session
164
+ self.action_queue = asyncio.Queue()
165
+
166
+ # Session creation time
167
+ self.created_at = time.time()
168
+ self.last_activity = time.time()
169
+
170
+ # Game state
171
+ self.current_scene = "forest" # Default scene
172
+ self.is_streaming = False
173
+ self.stream_task = None
174
+
175
+ # Current input state
176
+ self.keyboard_state = [0, 0, 0, 0, 0, 0] # forward, back, left, right, jump, attack
177
+ self.mouse_state = [0, 0] # x, y
178
+
179
+ self.background_tasks = []
180
+
181
+ async def start(self):
182
+ """Start all the queue processors for this session"""
183
+ self.background_tasks = [
184
+ asyncio.create_task(self._process_action_queue()),
185
+ ]
186
+ logger.info(f"Started game session for user {self.user_id}")
187
+
188
+ async def stop(self):
189
+ """Stop all background tasks for this session"""
190
+ # Stop streaming if active
191
+ if self.is_streaming and self.stream_task:
192
+ self.is_streaming = False
193
+ self.stream_task.cancel()
194
+ try:
195
+ await self.stream_task
196
+ except asyncio.CancelledError:
197
+ pass
198
+
199
+ # Cancel other background tasks
200
+ for task in self.background_tasks:
201
+ task.cancel()
202
+
203
+ try:
204
+ # Wait for tasks to complete cancellation
205
+ await asyncio.gather(*self.background_tasks, return_exceptions=True)
206
+ except asyncio.CancelledError:
207
+ pass
208
+
209
+ logger.info(f"Stopped game session for user {self.user_id}")
210
+
211
+ async def _process_action_queue(self):
212
+ """Process game actions from the queue"""
213
+ while True:
214
+ data = await self.action_queue.get()
215
+ try:
216
+ action_type = data.get('action')
217
+
218
+ if action_type == 'start_stream':
219
+ result = await self._handle_start_stream(data)
220
+ elif action_type == 'stop_stream':
221
+ result = await self._handle_stop_stream(data)
222
+ elif action_type == 'keyboard_input':
223
+ result = await self._handle_keyboard_input(data)
224
+ elif action_type == 'mouse_input':
225
+ result = await self._handle_mouse_input(data)
226
+ elif action_type == 'change_scene':
227
+ result = await self._handle_scene_change(data)
228
+ else:
229
+ result = {
230
+ 'action': action_type,
231
+ 'requestId': data.get('requestId'),
232
+ 'success': False,
233
+ 'error': f'Unknown action: {action_type}'
234
+ }
235
+
236
+ # Send response back to the client
237
+ await self.ws.send_json(result)
238
+
239
+ # Update last activity time
240
+ self.last_activity = time.time()
241
+
242
+ except Exception as e:
243
+ logger.error(f"Error processing action for user {self.user_id}: {str(e)}")
244
+ try:
245
+ await self.ws.send_json({
246
+ 'action': data.get('action'),
247
+ 'requestId': data.get('requestId', 'unknown'),
248
+ 'success': False,
249
+ 'error': f'Error processing action: {str(e)}'
250
+ })
251
+ except Exception as send_error:
252
+ logger.error(f"Error sending error response: {send_error}")
253
+ finally:
254
+ self.action_queue.task_done()
255
+
256
+ async def _handle_start_stream(self, data: Dict) -> Dict:
257
+ """Handle request to start streaming frames"""
258
+ if self.is_streaming:
259
+ return {
260
+ 'action': 'start_stream',
261
+ 'requestId': data.get('requestId'),
262
+ 'success': False,
263
+ 'error': 'Stream already active'
264
+ }
265
+
266
+ fps = data.get('fps', 16)
267
+ self.is_streaming = True
268
+ self.stream_task = asyncio.create_task(self._stream_frames(fps))
269
+
270
+ return {
271
+ 'action': 'start_stream',
272
+ 'requestId': data.get('requestId'),
273
+ 'success': True,
274
+ 'message': f'Streaming started at {fps} FPS'
275
+ }
276
+
277
+ async def _handle_stop_stream(self, data: Dict) -> Dict:
278
+ """Handle request to stop streaming frames"""
279
+ if not self.is_streaming:
280
+ return {
281
+ 'action': 'stop_stream',
282
+ 'requestId': data.get('requestId'),
283
+ 'success': False,
284
+ 'error': 'No active stream to stop'
285
+ }
286
+
287
+ self.is_streaming = False
288
+ if self.stream_task:
289
+ self.stream_task.cancel()
290
+ try:
291
+ await self.stream_task
292
+ except asyncio.CancelledError:
293
+ pass
294
+ self.stream_task = None
295
+
296
+ return {
297
+ 'action': 'stop_stream',
298
+ 'requestId': data.get('requestId'),
299
+ 'success': True,
300
+ 'message': 'Streaming stopped'
301
+ }
302
+
303
+ async def _handle_keyboard_input(self, data: Dict) -> Dict:
304
+ """Handle keyboard input from client"""
305
+ key = data.get('key', '')
306
+ pressed = data.get('pressed', False)
307
+
308
+ # Map key to keyboard state index
309
+ key_map = {
310
+ 'w': 0, 'forward': 0,
311
+ 's': 1, 'back': 1, 'backward': 1,
312
+ 'a': 2, 'left': 2,
313
+ 'd': 3, 'right': 3,
314
+ 'space': 4, 'jump': 4,
315
+ 'shift': 5, 'attack': 5, 'ctrl': 5
316
+ }
317
+
318
+ if key.lower() in key_map:
319
+ key_idx = key_map[key.lower()]
320
+ self.keyboard_state[key_idx] = 1 if pressed else 0
321
+
322
+ return {
323
+ 'action': 'keyboard_input',
324
+ 'requestId': data.get('requestId'),
325
+ 'success': True,
326
+ 'keyboardState': self.keyboard_state
327
+ }
328
+
329
+ async def _handle_mouse_input(self, data: Dict) -> Dict:
330
+ """Handle mouse movement/input from client"""
331
+ mouse_x = data.get('x', 0)
332
+ mouse_y = data.get('y', 0)
333
+
334
+ # Update mouse state, clamping values to the expected [-1, 1] range
335
+ self.mouse_state = [max(-1.0, min(1.0, float(mouse_x))), max(-1.0, min(1.0, float(mouse_y)))]
336
+
337
+ return {
338
+ 'action': 'mouse_input',
339
+ 'requestId': data.get('requestId'),
340
+ 'success': True,
341
+ 'mouseState': self.mouse_state
342
+ }
343
+
344
+ async def _handle_scene_change(self, data: Dict) -> Dict:
345
+ """Handle scene change requests"""
346
+ scene_name = data.get('scene', 'forest')
347
+ valid_scenes = self.game_manager.valid_scenes
348
+
349
+ if scene_name not in valid_scenes:
350
+ return {
351
+ 'action': 'change_scene',
352
+ 'requestId': data.get('requestId'),
353
+ 'success': False,
354
+ 'error': f'Invalid scene: {scene_name}. Valid scenes are: {", ".join(valid_scenes)}'
355
+ }
356
+
357
+ self.current_scene = scene_name
358
+
359
+ return {
360
+ 'action': 'change_scene',
361
+ 'requestId': data.get('requestId'),
362
+ 'success': True,
363
+ 'scene': scene_name
364
+ }
365
+
366
+ async def _stream_frames(self, fps: int):
367
+ """Stream frames to the client at the specified FPS"""
368
+ frame_interval = 1.0 / fps # Time between frames in seconds
369
+
370
+ try:
371
+ while self.is_streaming:
372
+ start_time = time.time()
373
+
374
+ # Generate frame based on current keyboard and mouse state
375
+ keyboard_condition = [self.keyboard_state]
376
+ mouse_condition = [self.mouse_state]
377
+
378
+ # Use the engine to generate the next frame
379
+ frame_bytes = self.game_manager.engine.generate_frame(
380
+ self.current_scene, keyboard_condition, mouse_condition
381
+ )
382
+
383
+ # Encode as base64 for sending in JSON
384
+ frame_base64 = base64.b64encode(frame_bytes).decode('utf-8')
385
+
386
+ # Send frame to client
387
+ await self.ws.send_json({
388
+ 'action': 'frame',
389
+ 'frameData': frame_base64,
390
+ 'timestamp': time.time()
391
+ })
392
+
393
+ # Calculate sleep time to maintain FPS
394
+ elapsed = time.time() - start_time
395
+ sleep_time = max(0, frame_interval - elapsed)
396
+ await asyncio.sleep(sleep_time)
397
+
398
+ except asyncio.CancelledError:
399
+ logger.info(f"Frame streaming cancelled for user {self.user_id}")
400
+ except Exception as e:
401
+ logger.error(f"Error in frame streaming for user {self.user_id}: {str(e)}")
402
+ if self.ws.closed:
403
+ logger.info(f"WebSocket closed for user {self.user_id}")
404
+ return
405
+
406
+ # Notify client of error
407
+ try:
408
+ await self.ws.send_json({
409
+ 'action': 'frame_error',
410
+ 'error': f'Streaming error: {str(e)}'
411
+ })
412
+ except Exception:
413
+ pass
414
+
415
+ # Stop streaming
416
+ self.is_streaming = False
417
+
418
+ class GameManager:
419
+ """
420
+ Manages all active gaming sessions and shared resources.
421
+ """
422
+ def __init__(self, args: argparse.Namespace):
423
+ self.sessions = {}
424
+ self.session_lock = asyncio.Lock()
425
+
426
+ # Try to import and initialize the game engine
427
+ try:
428
+ # Dynamically import the real engine if available
429
+ from src.envs.world_model_env import WorldModelEnv
430
+ # Initialize with model from args
431
+ self.engine = WorldModelEnv(args)
432
+ logger.info("Initialized World Model Environment")
433
+ except ImportError:
434
+ logger.warning("Could not import World Model Environment, falling back to simple engine")
435
+ self.engine = SimpleGameEngine(args)
436
+ except Exception as e:
437
+ logger.error(f"Error initializing World Model Environment: {str(e)}")
438
+ logger.warning("Falling back to simple engine")
439
+ self.engine = SimpleGameEngine(args)
440
+
441
+ # Load valid scenes from engine
442
+ self.valid_scenes = self.engine.get_valid_scenes()
443
+
444
+ async def create_session(self, user_id: str, ws: web.WebSocketResponse) -> GameSession:
445
+ """Create a new game session"""
446
+ async with self.session_lock:
447
+ # Create a new session for this user
448
+ session = GameSession(user_id, ws, self)
449
+ await session.start()
450
+ self.sessions[user_id] = session
451
+ return session
452
+
453
+ async def delete_session(self, user_id: str) -> None:
454
+ """Delete a game session and clean up resources"""
455
+ async with self.session_lock:
456
+ if user_id in self.sessions:
457
+ session = self.sessions[user_id]
458
+ await session.stop()
459
+ del self.sessions[user_id]
460
+ logger.info(f"Deleted game session for user {user_id}")
461
+
462
+ def get_session(self, user_id: str) -> Optional[GameSession]:
463
+ """Get a game session if it exists"""
464
+ return self.sessions.get(user_id)
465
+
466
+ async def close_all_sessions(self) -> None:
467
+ """Close all active sessions (used during shutdown)"""
468
+ async with self.session_lock:
469
+ for user_id, session in list(self.sessions.items()):
470
+ await session.stop()
471
+ self.sessions.clear()
472
+ logger.info("Closed all active game sessions")
473
+
474
+ @property
475
+ def session_count(self) -> int:
476
+ """Get the number of active sessions"""
477
+ return len(self.sessions)
478
+
479
+ def get_session_stats(self) -> Dict:
480
+ """Get statistics about active sessions"""
481
+ stats = {
482
+ 'total_sessions': len(self.sessions),
483
+ 'active_scenes': {},
484
+ 'streaming_sessions': 0
485
+ }
486
+
487
+ # Count sessions by scene and streaming status
488
+ for session in self.sessions.values():
489
+ scene = session.current_scene
490
+ stats['active_scenes'][scene] = stats['active_scenes'].get(scene, 0) + 1
491
+ if session.is_streaming:
492
+ stats['streaming_sessions'] += 1
493
+
494
+ return stats
495
+
496
+ # Create global game manager
497
+ game_manager = None
498
+
499
+ async def status_handler(request: web.Request) -> web.Response:
500
+ """Handler for API status endpoint"""
501
+ # Get session statistics
502
+ session_stats = game_manager.get_session_stats()
503
+
504
+ return web.json_response({
505
+ 'product': 'AI Game Multiverse Server',
506
+ 'version': '1.0.0',
507
+ 'active_sessions': session_stats,
508
+ 'available_scenes': game_manager.valid_scenes
509
+ })
510
+
511
+ async def root_handler(request: web.Request) -> web.Response:
512
+ """Handler for serving the client at the root path"""
513
+ index_path = pathlib.Path(__file__).parent / 'index.html'
514
+
515
+ if not index_path.exists():
516
+ return web.Response(text="""
517
+ <html>
518
+ <body style="font-family: Arial, sans-serif; text-align: center; padding: 50px;">
519
+ <h1>AI Game Multiverse Server</h1>
520
+ <p>Server is running, but the index.html file is missing.</p>
521
+ <p>Please create the index.html file in the same directory as the server.py file.</p>
522
+ </body>
523
+ </html>
524
+ """, content_type='text/html')
525
+
526
+ with open(index_path, 'r', encoding='utf-8') as file:
527
+ html_content = file.read()
528
+
529
+ return web.Response(text=html_content, content_type='text/html')
530
+
531
+ async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
532
+ """Handle WebSocket connections with robust error handling"""
533
+ logger.info(f"WebSocket connection attempt - PATH: {request.path}, QUERY: {request.query_string}")
534
+
535
+ # Log request headers at debug level only (could contain sensitive information)
536
+ logger.debug(f"WebSocket request headers: {dict(request.headers)}")
537
+
538
+ # Prepare a WebSocket response with appropriate settings
539
+ ws = web.WebSocketResponse(
540
+ max_msg_size=1024*1024*10, # 10MB max message size
541
+ timeout=60.0,
542
+ heartbeat=30.0 # Add heartbeat to keep connection alive
543
+ )
544
+
545
+ # Check if WebSocket protocol is supported
546
+ if not ws.can_prepare(request):
547
+ logger.error("Cannot prepare WebSocket: WebSocket protocol not supported")
548
+ return web.Response(status=400, text="WebSocket protocol not supported")
549
+
550
+ try:
551
+ logger.info("Preparing WebSocket connection...")
552
+ await ws.prepare(request)
553
+
554
+ # Generate a unique user ID for this connection
555
+ user_id = str(uuid.uuid4())
556
+
557
+ # Get client IP address
558
+ peername = request.transport.get_extra_info('peername')
559
+ if peername is not None:
560
+ client_ip = peername[0]
561
+ else:
562
+ client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()
563
+
564
+ # Log connection success
565
+ logger.info(f"Client {user_id} connecting from IP: {client_ip} - WebSocket connection established")
566
+
567
+ # Mark that the session is established
568
+ is_session_created = False
569
+
570
+ try:
571
+ # Store the user ID in the websocket for easy access
572
+ ws.user_id = user_id
573
+
574
+ # Create a new session for this user
575
+ logger.info(f"Creating game session for user {user_id}")
576
+ user_session = await game_manager.create_session(user_id, ws)
577
+ is_session_created = True
578
+ logger.info(f"Game session created for user {user_id}")
579
+ except Exception as session_error:
580
+ logger.error(f"Error creating game session: {str(session_error)}", exc_info=True)
581
+ if not ws.closed:
582
+ await ws.close(code=1011, message=f"Server error: {str(session_error)}".encode())
583
+ if is_session_created:
584
+ await game_manager.delete_session(user_id)
585
+ return ws
586
+ except Exception as e:
587
+ logger.error(f"Error establishing WebSocket connection: {str(e)}", exc_info=True)
588
+ if not ws.closed and ws.prepared:
589
+ await ws.close(code=1011, message=f"Server error: {str(e)}".encode())
590
+ return ws
591
+
592
+ # Send initial welcome message
593
+ try:
594
+ await ws.send_json({
595
+ 'action': 'welcome',
596
+ 'userId': user_id,
597
+ 'message': 'Welcome to the AI Game Multiverse WebSocket server!',
598
+ 'scenes': game_manager.valid_scenes
599
+ })
600
+ logger.info(f"Sent welcome message to user {user_id}")
601
+ except Exception as welcome_error:
602
+ logger.error(f"Error sending welcome message: {str(welcome_error)}")
603
+ if not ws.closed:
604
+ await ws.close(code=1011, message=b"Failed to send welcome message")
605
+ await game_manager.delete_session(user_id)
606
+ return ws
607
+
608
+ try:
609
+ async for msg in ws:
610
+ if msg.type == WSMsgType.TEXT:
611
+ try:
612
+ data = json.loads(msg.data)
613
+ action = data.get('action')
614
+
615
+ logger.debug(f"Received {action} message from user {user_id}")
616
+
617
+ if action == 'ping':
618
+ # Respond to ping immediately
619
+ await ws.send_json({
620
+ 'action': 'pong',
621
+ 'requestId': data.get('requestId'),
622
+ 'timestamp': time.time()
623
+ })
624
+ else:
625
+ # Route game actions to the session's action queue
626
+ await user_session.action_queue.put(data)
627
+
628
+ except json.JSONDecodeError:
629
+ logger.error(f"Invalid JSON from user {user_id}: {msg.data}")
630
+ if not ws.closed:
631
+ await ws.send_json({
632
+ 'error': 'Invalid JSON message',
633
+ 'success': False
634
+ })
635
+ except Exception as e:
636
+ logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
637
+ if not ws.closed:
638
+ await ws.send_json({
639
+ 'action': data.get('action') if 'data' in locals() else 'unknown',
640
+ 'success': False,
641
+ 'error': f'Error processing message: {str(e)}'
642
+ })
643
+
644
+ elif msg.type == WSMsgType.ERROR:
645
+ logger.error(f"WebSocket error for user {user_id}: {ws.exception()}")
646
+ break
647
+
648
+ elif msg.type == WSMsgType.CLOSE:
649
+ logger.info(f"WebSocket close received for user {user_id} (code: {msg.data}, message: {msg.extra})")
650
+ break
651
+
652
+ elif msg.type == WSMsgType.CLOSING:
653
+ logger.info(f"WebSocket closing for user {user_id}")
654
+ break
655
+
656
+ elif msg.type == WSMsgType.CLOSED:
657
+ logger.info(f"WebSocket already closed for user {user_id}")
658
+ break
659
+
660
+ except Exception as ws_error:
661
+ logger.error(f"Unexpected WebSocket error for user {user_id}: {str(ws_error)}", exc_info=True)
662
+ finally:
663
+ # Cleanup session
664
+ try:
665
+ logger.info(f"Cleaning up session for user {user_id}")
666
+ await game_manager.delete_session(user_id)
667
+ logger.info(f"Connection closed for user {user_id}")
668
+ except Exception as cleanup_error:
669
+ logger.error(f"Error during session cleanup for user {user_id}: {str(cleanup_error)}")
670
+
671
+ return ws
672
+
673
+ async def init_app(args, base_path="") -> web.Application:
674
+ """Initialize the web application"""
675
+ global game_manager
676
+
677
+ # Initialize game manager with command line args
678
+ game_manager = GameManager(args)
679
+
680
+ app = web.Application(
681
+ client_max_size=1024**2*10 # 10MB max size
682
+ )
683
+
684
+ # Add cleanup logic
685
+ async def cleanup(app):
686
+ logger.info("Shutting down server, closing all sessions...")
687
+ await game_manager.close_all_sessions()
688
+
689
+ app.on_shutdown.append(cleanup)
690
+
691
+ # Add routes with CORS headers for WebSockets
692
+ # Configure CORS for all routes
693
+ @web.middleware
694
+ async def cors_middleware(request, handler):
695
+ if request.method == 'OPTIONS':
696
+ # Handle preflight requests
697
+ resp = web.Response()
698
+ resp.headers['Access-Control-Allow-Origin'] = '*'
699
+ resp.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
700
+ resp.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
701
+ return resp
702
+
703
+ # Normal request, call the handler
704
+ resp = await handler(request)
705
+
706
+ # Add CORS headers to the response
707
+ resp.headers['Access-Control-Allow-Origin'] = '*'
708
+ resp.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
709
+ resp.headers['Access-Control-Allow-Headers'] = 'Content-Type, X-Requested-With'
710
+ return resp
711
+
712
+ app.middlewares.append(cors_middleware)
713
+
714
+ # Add a debug endpoint to help diagnose WebSocket issues
715
+ async def debug_handler(request):
716
+ client_ip = request.remote
717
+ headers = dict(request.headers)
718
+ server_host = request.host
719
+
720
+ debug_info = {
721
+ "client_ip": client_ip,
722
+ "server_host": server_host,
723
+ "headers": headers,
724
+ "request_path": request.path,
725
+ "server_time": time.time(),
726
+ "base_path": base_path,
727
+ "websocket_route": f"{base_path}/ws",
728
+ "all_routes": [route.name for route in app.router.routes() if route.name],
729
+ "server_info": {
730
+ "active_sessions": game_manager.session_count,
731
+ "available_scenes": game_manager.valid_scenes
732
+ }
733
+ }
734
+
735
+ return web.json_response(debug_info)
736
+
737
+ # Set up routes with the base_path
738
+ # Add multiple WebSocket routes to ensure compatibility
739
+ logger.info(f"Setting up WebSocket route at {base_path}/ws")
740
+ app.router.add_get(f'{base_path}/ws', websocket_handler, name='ws_handler')
741
+
742
+ # Also add WebSocket route at the root for compatibility
743
+ if base_path:
744
+ logger.info(f"Adding additional WebSocket route at /ws")
745
+ app.router.add_get('/ws', websocket_handler, name='ws_root_handler')
746
+
747
+ # Add routes for API and debug endpoints
748
+ app.router.add_get(f'{base_path}/api/status', status_handler, name='status_handler')
749
+ app.router.add_get(f'{base_path}/api/debug', debug_handler, name='debug_handler')
750
+
751
+ # Serve the client at both the base path and root path for compatibility
752
+ app.router.add_get(f'{base_path}/', root_handler, name='root_handler')
753
+
754
+ # Always serve at the root path for compatibility
755
+ if base_path:
756
+ app.router.add_get('/', root_handler, name='root_handler_no_base')
757
+
758
+ # Set up static file serving for assets
759
+ static_path = pathlib.Path(__file__).parent / 'assets'
760
+ if not static_path.exists():
761
+ static_path.mkdir(exist_ok=True)
762
+
763
+ app.router.add_static(f'{base_path}/assets', static_path, name='static_handler')
764
+
765
+ # Add static file serving at root for compatibility
766
+ if base_path:
767
+ app.router.add_static('/assets', static_path, name='static_handler_no_base')
768
+
769
+ return app
770
+
771
+ def parse_args() -> argparse.Namespace:
772
+ """Parse server-specific command line arguments"""
773
+ parser = argparse.ArgumentParser(description="AI Game Multiverse Server")
774
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host IP to bind to")
775
+ parser.add_argument("--port", type=int, default=8080, help="Port to listen on")
776
+ parser.add_argument("--path", type=str, default="", help="Base path for the server (for proxy setups)")
777
+
778
+ # Add model-specific arguments
779
+ parser.add_argument("--frame_width", type=int, default=640, help="Width of output frames")
780
+ parser.add_argument("--frame_height", type=int, default=360, help="Height of output frames")
781
+ parser.add_argument("--fps", type=int, default=16, help="Target frames per second")
782
+
783
+ args = parser.parse_args()
784
+ return args
785
+
786
+ if __name__ == '__main__':
787
+ # Parse command line arguments
788
+ args = parse_args()
789
+
790
+ # Initialize app
791
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
792
+ app = loop.run_until_complete(init_app(args, base_path=args.path))
793
+
794
+ # Start server
795
+ logger.info(f"Starting AI Game Multiverse Server at {args.host}:{args.port}")
796
+ web.run_app(app, host=args.host, port=args.port)
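For reference, a minimal client-side sketch of the message protocol handled above. This is an illustration, not part of the upload: the `websockets` package, the local URL, and the frame count are assumptions.

import asyncio
import base64
import json

import websockets  # third-party client library, assumed installed


async def demo():
    # Connect to the WebSocket route registered by init_app()
    async with websockets.connect("ws://localhost:8080/ws") as ws:
        print(json.loads(await ws.recv()))  # 'welcome' message with available scenes

        # Start streaming and hold W, mirroring GameSession's handlers
        await ws.send(json.dumps({"action": "start_stream", "fps": 16, "requestId": "r1"}))
        await ws.send(json.dumps({"action": "keyboard_input", "key": "w", "pressed": True, "requestId": "r2"}))

        # Each 'frame' message carries a base64-encoded JPEG
        for _ in range(10):
            msg = json.loads(await ws.recv())
            if msg.get("action") == "frame":
                print(f"received frame of {len(base64.b64decode(msg['frameData']))} bytes")

        await ws.send(json.dumps({"action": "stop_stream", "requestId": "r3"}))

asyncio.run(demo())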
src/.DS_Store ADDED
Binary file (6.15 kB).
src/__init__.py ADDED
File without changes
src/agent.py ADDED
@@ -0,0 +1,74 @@
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from envs import TorchEnv, WorldModelEnv
9
+ from models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig
10
+ from models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig
11
+ from models.rew_end_model import RewEndModel, RewEndModelConfig
12
+ from utils import extract_state_dict
13
+
14
+
15
+ @dataclass
16
+ class AgentConfig:
17
+ denoiser: DenoiserConfig
18
+ upsampler: Optional[DenoiserConfig]
19
+ rew_end_model: Optional[RewEndModelConfig]
20
+ actor_critic: Optional[ActorCriticConfig]
21
+ num_actions: int
22
+
23
+ def __post_init__(self) -> None:
24
+ self.denoiser.inner_model.num_actions = self.num_actions
25
+ if self.upsampler is not None:
26
+ self.upsampler.inner_model.num_actions = self.num_actions
27
+ if self.rew_end_model is not None:
28
+ self.rew_end_model.num_actions = self.num_actions
29
+ if self.actor_critic is not None:
30
+ self.actor_critic.num_actions = self.num_actions
31
+
32
+
33
+ class Agent(nn.Module):
34
+ def __init__(self, cfg: AgentConfig) -> None:
35
+ super().__init__()
36
+ self.denoiser = Denoiser(cfg.denoiser)
37
+ self.upsampler = Denoiser(cfg.upsampler) if cfg.upsampler is not None else None
38
+ self.rew_end_model = RewEndModel(cfg.rew_end_model) if cfg.rew_end_model is not None else None
39
+ self.actor_critic = ActorCritic(cfg.actor_critic) if cfg.actor_critic is not None else None
40
+
41
+ @property
42
+ def device(self):
43
+ return self.denoiser.device
44
+
45
+ def setup_training(
46
+ self,
47
+ sigma_distribution_cfg: SigmaDistributionConfig,
48
+ sigma_distribution_cfg_upsampler: Optional[SigmaDistributionConfig],
49
+ actor_critic_loss_cfg: Optional[ActorCriticLossConfig],
50
+ rl_env: Optional[Union[TorchEnv, WorldModelEnv]],
51
+ ) -> None:
52
+ self.denoiser.setup_training(sigma_distribution_cfg)
53
+ if self.upsampler is not None:
54
+ self.upsampler.setup_training(sigma_distribution_cfg_upsampler)
55
+ if self.actor_critic is not None:
56
+ self.actor_critic.setup_training(rl_env, actor_critic_loss_cfg)
57
+
58
+ def load(
59
+ self,
60
+ path_to_ckpt: Path,
61
+ load_denoiser: bool = True,
62
+ load_upsampler: bool = True,
63
+ load_rew_end_model: bool = True,
64
+ load_actor_critic: bool = True,
65
+ ) -> None:
66
+ sd = torch.load(Path(path_to_ckpt), map_location=self.device)
67
+ if load_denoiser:
68
+ self.denoiser.load_state_dict(extract_state_dict(sd, "denoiser"))
69
+ if load_upsampler and self.upsampler is not None:
70
+ self.upsampler.load_state_dict(extract_state_dict(sd, "upsampler"))
71
+ if load_rew_end_model and self.rew_end_model is not None:
72
+ self.rew_end_model.load_state_dict(extract_state_dict(sd, "rew_end_model"))
73
+ if load_actor_critic and self.actor_critic is not None:
74
+ self.actor_critic.load_state_dict(extract_state_dict(sd, "actor_critic"))
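A short sketch of restoring an Agent from a training checkpoint; `cfg` and the checkpoint path are placeholders, since the real AgentConfig is assembled from the repository's config files.

# Assuming `cfg` is a fully populated AgentConfig and "agent.pt" exists.
agent = Agent(cfg)
agent.load(
    "agent.pt",
    load_rew_end_model=False,  # skipped anyway when the submodule is None
    load_actor_critic=False,
)
agent.denoiser.eval()  # the denoiser acts as the world model at inference time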
src/coroutines/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from functools import wraps
2
+
3
+
4
+ def coroutine(func):
5
+ @wraps(func)
6
+ def primer(*args, **kwargs):
7
+ gen = func(*args, **kwargs)
8
+ next(gen)
9
+ return gen
10
+
11
+ return primer
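The decorator primes a generator by advancing it to its first yield, so callers can use .send() immediately. A minimal usage sketch (the running_sum example is illustrative only):

@coroutine
def running_sum():
    total = 0
    x = yield  # execution parks here after priming
    while True:
        total += x
        x = yield total

gen = running_sum()
print(gen.send(2))  # 2
print(gen.send(3))  # 5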
src/coroutines/collector.py ADDED
@@ -0,0 +1,126 @@
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ from typing import Generator, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from tqdm import tqdm
8
+
9
+ from . import coroutine
10
+ from data import Episode, Dataset
11
+ from envs import TorchEnv
12
+ from .env_loop import make_env_loop
13
+ from utils import Logs
14
+
15
+
16
+ @coroutine
17
+ def make_collector(
18
+ env: TorchEnv,
19
+ model: nn.Module,
20
+ dataset: Dataset,
21
+ epsilon: float = 0.0,
22
+ reset_every_collect: bool = False,
23
+ verbose: bool = True,
24
+ ) -> Generator[Logs, int, None]:
25
+ num_envs = env.num_envs
26
+
27
+ env_loop, buffer, episode_ids, dead = (None,) * 4
28
+ num_steps, num_episodes, to_log, pbar = (None,) * 4
29
+
30
+ def setup_new_collect():
31
+ nonlocal num_steps, num_episodes, buffer, to_log, pbar
32
+ num_steps = 0
33
+ num_episodes = 0
34
+ buffer = defaultdict(list)
35
+ to_log = []
36
+ pbar = tqdm(
37
+ total=num_to_collect.total,
38
+ unit=num_to_collect.unit,
39
+ desc=f"Collect {dataset.name}",
40
+ disable=not verbose,
41
+ )
42
+
43
+ def reset():
44
+ nonlocal env_loop, episode_ids, dead
45
+ env_loop = make_env_loop(env, model, epsilon)
46
+ episode_ids = defaultdict(lambda: None)
47
+ dead = [None] * num_envs
48
+
49
+ num_to_collect = yield
50
+ setup_new_collect()
51
+ reset()
52
+
53
+ while True:
54
+ with torch.no_grad():
55
+ all_obs, act, rew, end, trunc, *_, [infos] = env_loop.send(1)
56
+
57
+ num_steps += num_envs
58
+ pbar.update(num_envs if num_to_collect.steps is not None else 0)
59
+
60
+ for i, (o, a, r, e, t) in enumerate(zip(all_obs, act, rew, end, trunc)):
61
+ buffer[i].append((o, a, r, e, t))
62
+ dead[i] = (e + t).clip(max=1).item()
63
+
64
+ num_episodes += sum(dead)
65
+
66
+ can_stop = num_to_collect.can_stop(num_steps, num_episodes)
67
+
68
+ count_dead = 0
69
+ for i in range(num_envs):
70
+ # Store incomplete episodes only when reset_every_collect is set to False (train)
71
+ add_to_dataset = dead[i] or (can_stop and not reset_every_collect)
72
+ if add_to_dataset:
73
+ info = {"final_observation": infos["final_observation"][count_dead]} if dead[i] else {}
74
+ ep = Episode(*(torch.cat(x, dim=0) for x in zip(*buffer[i])), info).to("cpu")
75
+ if episode_ids[i] is not None:
76
+ ep = dataset.load_episode(episode_ids[i]) + ep
77
+ episode_ids[i] = dataset.add_episode(ep, episode_id=episode_ids[i])
78
+
79
+ if dead[i]:
80
+ to_log.append(
81
+ {
82
+ f"{dataset.name}/episode_id": episode_ids[i],
83
+ **ep.compute_metrics(),
84
+ }
85
+ )
86
+ buffer[i] = []
87
+ episode_ids[i] = None
88
+ pbar.update(1 if num_to_collect.episodes is not None else 0)
89
+
90
+ count_dead += dead[i]
91
+
92
+ if can_stop:
93
+ pbar.close()
94
+ metrics = {
95
+ "num_steps": dataset.num_steps,
96
+ "counts/rew_-1": dataset.counts_rew[0],
97
+ "counts/rew__0": dataset.counts_rew[1],
98
+ "counts/rew_+1": dataset.counts_rew[2],
99
+ "counts/end_0": dataset.counts_end[0],
100
+ "counts/end_1": dataset.counts_end[1],
101
+ }
102
+ to_log.append({f"{dataset.name}/{k}": v for k, v in metrics.items()})
103
+ num_to_collect = yield to_log
104
+ setup_new_collect()
105
+ if reset_every_collect:
106
+ reset()
107
+
108
+
109
+ @dataclass
110
+ class NumToCollect:
111
+ steps: Optional[int] = None
112
+ episodes: Optional[int] = None
113
+
114
+ def __post_init__(self) -> None:
115
+ assert (self.steps is None) != (self.episodes is None)
116
+
117
+ def can_stop(self, num_steps: int, num_episodes: int) -> bool:
118
+ return num_steps >= self.steps if self.steps is not None else num_episodes >= self.episodes
119
+
120
+ @property
121
+ def unit(self) -> str:
122
+ return "steps" if self.steps is not None else "eps"
123
+
124
+ @property
125
+ def total(self) -> int:
126
+ return self.steps if self.steps is not None else self.episodes
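A sketch of how a training loop might drive the collector; `env`, `model` and `dataset` are placeholders for objects constructed elsewhere in the repository.

collector = make_collector(env, model, dataset, epsilon=0.01)

# The coroutine is already primed, so the first .send() delivers a budget.
logs = collector.send(NumToCollect(steps=1000))   # collect by step count
logs += collector.send(NumToCollect(episodes=5))  # or by episode count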
src/coroutines/env_loop.py ADDED
@@ -0,0 +1,74 @@
1
+ import random
2
+ from typing import Generator, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.distributions.categorical import Categorical
7
+
8
+ from . import coroutine
9
+ from envs import TorchEnv, WorldModelEnv
10
+
11
+
12
+ @coroutine
13
+ def make_env_loop(
14
+ env: Union[TorchEnv, WorldModelEnv], model: nn.Module, epsilon: float = 0.0
15
+ ) -> Generator[Tuple[torch.Tensor, ...], int, None]:
16
+ num_steps = yield
17
+
18
+ hx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
19
+ cx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
20
+
21
+ seed = random.randint(0, 2**31 - 1)
22
+ obs, _ = env.reset(seed=[seed + i for i in range(env.num_envs)])
23
+
24
+ while True:
25
+ hx, cx = hx.detach(), cx.detach()
26
+ all_ = []
27
+ infos = []
28
+ n = 0
29
+
30
+ while n < num_steps:
31
+ logits_act, val, (hx, cx) = model.predict_act_value(obs, (hx, cx))
32
+ act = Categorical(logits=logits_act).sample()
33
+
34
+ if random.random() < epsilon:
35
+ act = torch.randint(low=0, high=env.num_actions, size=(obs.size(0),), device=obs.device)
36
+
37
+ next_obs, rew, end, trunc, info = env.step(act)
38
+
39
+ if n > 0:
40
+ val_bootstrap = val.detach().clone()
41
+ if dead.any():
42
+ val_bootstrap[dead] = val_final_obs
43
+ all_[-1][-1] = val_bootstrap
44
+
45
+ dead = torch.logical_or(end, trunc)
46
+
47
+ if dead.any():
48
+ with torch.no_grad():
49
+ _, val_final_obs, _ = model.predict_act_value(info["final_observation"], (hx[dead], cx[dead]))
50
+ reset_gate = 1 - dead.float().unsqueeze(1)
51
+ hx = hx * reset_gate
52
+ cx = cx * reset_gate
53
+ if "burnin_obs" in info:
54
+ burnin_obs = info["burnin_obs"]
55
+ for i in range(burnin_obs.size(1)):
56
+ _, _, (hx[dead], cx[dead]) = model.predict_act_value(burnin_obs[:, i], (hx[dead], cx[dead]))
57
+
58
+ all_.append([obs, act, rew, end, trunc, logits_act, val, None])
59
+ infos.append(info)
60
+
61
+ obs = next_obs
62
+ n += 1
63
+
64
+ with torch.no_grad():
65
+ _, val_bootstrap, _ = model.predict_act_value(next_obs, (hx, cx)) # do not update hx/cx
66
+
67
+ if dead.any():
68
+ val_bootstrap[dead] = val_final_obs
69
+
70
+ all_[-1][-1] = val_bootstrap
71
+
72
+ all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap = (torch.stack(x, dim=1) for x in zip(*all_))
73
+
74
+ num_steps = yield all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap, infos
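Each .send(n) returns tensors stacked along a time dimension of length n, with LSTM state carried across calls. A consuming sketch with placeholder env/model objects:

env_loop = make_env_loop(env, model, epsilon=0.0)
obs, act, rew, end, trunc, logits_act, val, val_bootstrap, infos = env_loop.send(16)
print(obs.shape)  # (num_envs, 16, ...) — one 16-step rollout per environment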
src/data/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .batch import Batch
2
+ from .batch_sampler import BatchSampler
3
+ from .dataset import Dataset, GameHdf5Dataset
4
+ from .episode import Episode
5
+ from .segment import Segment, SegmentId
6
+ from .utils import collate_segments_to_batch, DatasetTraverser, make_segment
src/data/batch.py ADDED
@@ -0,0 +1,25 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List
4
+
5
+ import torch
6
+
7
+ from .segment import SegmentId
8
+
9
+
10
+ @dataclass
11
+ class Batch:
12
+ obs: torch.ByteTensor
13
+ act: torch.LongTensor
14
+ rew: torch.FloatTensor
15
+ end: torch.LongTensor
16
+ trunc: torch.LongTensor
17
+ mask_padding: torch.BoolTensor
18
+ info: List[Dict[str, Any]]
19
+ segment_ids: List[SegmentId]
20
+
21
+ def pin_memory(self) -> Batch:
22
+ return Batch(**{k: v if k in ("segment_ids", "info") else v.pin_memory() for k, v in self.__dict__.items()})
23
+
24
+ def to(self, device: torch.device) -> Batch:
25
+ return Batch(**{k: v if k in ("segment_ids", "info") else v.to(device) for k, v in self.__dict__.items()})
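Batch is what collate_segments_to_batch (in src/data/utils.py) produces; a sketch of the intended DataLoader wiring, with placeholder dataset/sampler objects:

from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_sampler=batch_sampler,           # yields lists of SegmentId
    collate_fn=collate_segments_to_batch,  # returns a Batch
    pin_memory=True,                       # dispatches to Batch.pin_memory()
)
batch = next(iter(loader)).to("cuda")      # tensors move; info/segment_ids stay put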
src/data/batch_sampler.py ADDED
@@ -0,0 +1,100 @@
1
+ from typing import Generator, List, Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from .dataset import GameHdf5Dataset, Dataset
7
+ from .segment import SegmentId
8
+
9
+
10
+ class BatchSampler(torch.utils.data.Sampler):
11
+ def __init__(
12
+ self,
13
+ dataset: Dataset,
14
+ rank: int,
15
+ world_size: int,
16
+ batch_size: int,
17
+ seq_length: int,
18
+ sample_weights: Optional[List[float]] = None,
19
+ can_sample_beyond_end: bool = False,
20
+ autoregressive_obs: Optional[int] = None,
21
+ initial_num_consecutive_page_count: int = 1
22
+ ) -> None:
23
+ super().__init__(dataset)
24
+ assert isinstance(dataset, (Dataset, GameHdf5Dataset))
25
+ self.dataset = dataset
26
+ self.rank = rank
27
+ self.world_size = world_size
28
+ self.sample_weights = sample_weights
29
+ self.batch_size = batch_size
30
+ self.seq_length = seq_length
31
+ self.can_sample_beyond_end = can_sample_beyond_end
32
+ self.autoregressive_obs = autoregressive_obs
33
+ self.num_consecutive_batches = initial_num_consecutive_page_count
34
+
35
+ def __len__(self):
36
+ raise NotImplementedError
37
+
38
+ def __iter__(self) -> Generator[List[SegmentId], None, None]:
39
+ segments = None
40
+ current_iter = 0
41
+
42
+ while True:
43
+ if current_iter == 0:
44
+ segments = self.sample()
45
+ else:
46
+ segments = self.next(segments)
47
+
48
+ current_iter = (current_iter + 1) % self.num_consecutive_batches
49
+ yield segments
50
+
51
+ def next(self, segments: List[SegmentId]):
52
+ return [
53
+ SegmentId(segment.episode_id, segment.stop, segment.stop + self.autoregressive_obs, False)
54
+ for segment in segments
55
+ ]
56
+
57
+ def sample(self) -> List[SegmentId]:
58
+ total_length = self.seq_length + (self.num_consecutive_batches - 1) * self.autoregressive_obs
59
+
60
+ num_episodes = self.dataset.num_episodes
61
+
62
+ if (self.sample_weights is None) or num_episodes < len(self.sample_weights):
63
+ weights = self.dataset.lengths / self.dataset.num_steps
64
+ else:
65
+ weights = self.sample_weights
66
+ num_weights = len(self.sample_weights)
67
+ assert all([0 <= x <= 1 for x in weights]) and sum(weights) == 1
68
+ sizes = [
69
+ num_episodes // num_weights + (num_episodes % num_weights) * (i == num_weights - 1)
70
+ for i in range(num_weights)
71
+ ]
72
+ weights = [w / s for (w, s) in zip(weights, sizes) for _ in range(s)]
73
+
74
+ episodes_partition = np.arange(self.rank, num_episodes, self.world_size)
75
+ episode_lengths = self.dataset.lengths[episodes_partition]
76
+ valid_mask = episode_lengths > total_length # valid episodes must be long enough for autoregressive generation
77
+ episodes_partition = episodes_partition[valid_mask]
78
+
79
+ weights = np.array(weights[self.rank::self.world_size])
80
+ weights = weights[valid_mask]
81
+
82
+ max_eps = self.batch_size
83
+ episode_ids = np.random.choice(episodes_partition, size=max_eps, replace=True, p=weights / weights.sum())
84
+ episode_ids = episode_ids.repeat(self.batch_size // max_eps)
85
+
86
+ # choose a random timestep within each sampled episode
87
+ timesteps = np.random.randint(low=0, high=self.dataset.lengths[episode_ids])
88
+ # total_length (above) is the initial context plus the autoregressive generation length
89
+
90
+ # the stop of the first page can be at most the episode length minus the frames generated autoregressively in the subsequent pages
91
+ stops = np.minimum(
92
+ self.dataset.lengths[episode_ids] - (self.num_consecutive_batches - 1) * self.seq_length,
93
+ timesteps + 1 + np.random.randint(0, total_length, len(timesteps))
94
+ )
95
+ # stops must be at least the initial context plus the first-page prediction size
96
+ stops = np.maximum(stops, self.seq_length)
97
+ # starts are the stops minus the initial context and the first-page prediction size
98
+ starts = stops - self.seq_length
99
+
100
+ return [SegmentId(*x, True) for x in zip(episode_ids, starts, stops)]
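A concrete picture of the paging scheme: with seq_length=24, autoregressive_obs=8 and num_consecutive_batches=2, sample() draws a 24-frame first page and next() continues it by 8 frames (values below are illustrative):

first_page = SegmentId(episode_id=3, start=40, stop=64, is_first_batch=True)
# next() produces the window immediately after it:
follow_up = SegmentId(3, first_page.stop, first_page.stop + 8, False)
print(follow_up)  # SegmentId(episode_id=3, start=64, stop=72, is_first_batch=False)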
src/data/dataset.py ADDED
@@ -0,0 +1,219 @@
1
+ from collections import Counter
2
+ import multiprocessing as mp
3
+ from pathlib import Path
4
+ import shutil
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ import h5py
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.utils.data import Dataset as TorchDataset
12
+
13
+ from .episode import Episode
14
+ from .segment import Segment, SegmentId
15
+ from .utils import make_segment
16
+ from utils import StateDictMixin
17
+
18
+
19
+ class Dataset(StateDictMixin, TorchDataset):
20
+ def __init__(
21
+ self,
22
+ directory: Path,
23
+ dataset_full_res: Optional[TorchDataset],
24
+ name: Optional[str] = None,
25
+ cache_in_ram: bool = False,
26
+ use_manager: bool = False,
27
+ save_on_disk: bool = True,
28
+ ) -> None:
29
+ super().__init__()
30
+
31
+ # State
32
+ self.is_static = False
33
+ self.num_episodes = None
34
+ self.num_steps = None
35
+ self.start_idx = None
36
+ self.lengths = None
37
+ self.counter_rew = None
38
+ self.counter_end = None
39
+
40
+ self._directory = Path(directory).expanduser()
41
+ self._name = name if name is not None else self._directory.stem
42
+ self._cache_in_ram = cache_in_ram
43
+ self._save_on_disk = save_on_disk
44
+ self._default_path = self._directory / "info.pt"
45
+ self._cache = mp.Manager().dict() if use_manager else {}
46
+ self._reset()
47
+
48
+ self._dataset_full_res = dataset_full_res
49
+
50
+ def __len__(self) -> int:
51
+ return self.num_steps
52
+
53
+ def __getitem__(self, segment_id: SegmentId) -> Segment:
54
+ episode = self.load_episode(segment_id.episode_id)
55
+ segment = make_segment(episode, segment_id, should_pad=True)
56
+ if self._dataset_full_res is not None:
57
+ segment_id_full_res = SegmentId(episode.info["original_file_id"], segment_id.start, segment_id.stop, segment_id.is_first_batch)
58
+ segment.info["full_res"] = self._dataset_full_res[segment_id_full_res].obs
59
+ elif "full_res" in segment.info:
60
+ segment.info["full_res"] = segment.info["full_res"][segment_id.start:segment_id.stop]
61
+ return segment
62
+
63
+ def __str__(self) -> str:
64
+ return f"{self.name}: {self.num_episodes} episodes, {self.num_steps} steps."
65
+
66
+ @property
67
+ def name(self) -> str:
68
+ return self._name
69
+
70
+ @property
71
+ def counts_rew(self) -> List[int]:
72
+ return [self.counter_rew[r] for r in [-1, 0, 1]]
73
+
74
+ @property
75
+ def counts_end(self) -> List[int]:
76
+ return [self.counter_end[e] for e in [0, 1]]
77
+
78
+ def _reset(self) -> None:
79
+ self.num_episodes = 0
80
+ self.num_steps = 0
81
+ self.start_idx = np.array([], dtype=np.int64)
82
+ self.lengths = np.array([], dtype=np.int64)
83
+ self.counter_rew = Counter()
84
+ self.counter_end = Counter()
85
+ self._cache.clear()
86
+
87
+ def clear(self) -> None:
88
+ self.assert_not_static()
89
+ if self._directory.is_dir():
90
+ shutil.rmtree(self._directory)
91
+ self._reset()
92
+
93
+ def load_episode(self, episode_id: int) -> Episode:
94
+ if self._cache_in_ram and episode_id in self._cache:
95
+ episode = self._cache[episode_id]
96
+ else:
97
+ episode = Episode.load(self._get_episode_path(episode_id))
98
+ if self._cache_in_ram:
99
+ self._cache[episode_id] = episode
100
+ return episode
101
+
102
+ def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int:
103
+ self.assert_not_static()
104
+ episode = episode.to("cpu")
105
+
106
+ if episode_id is None:
107
+ episode_id = self.num_episodes
108
+ self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps])))
109
+ self.lengths = np.concatenate((self.lengths, np.array([len(episode)])))
110
+ self.num_steps += len(episode)
111
+ self.num_episodes += 1
112
+
113
+ else:
114
+ assert episode_id < self.num_episodes
115
+ old_episode = self.load_episode(episode_id)
116
+ incr_num_steps = len(episode) - len(old_episode)
117
+ self.lengths[episode_id] = len(episode)
118
+ self.start_idx[episode_id + 1 :] += incr_num_steps
119
+ self.num_steps += incr_num_steps
120
+ self.counter_rew.subtract(old_episode.rew.sign().tolist())
121
+ self.counter_end.subtract(old_episode.end.tolist())
122
+
123
+ self.counter_rew.update(episode.rew.sign().tolist())
124
+ self.counter_end.update(episode.end.tolist())
125
+
126
+ if self._save_on_disk:
127
+ episode.save(self._get_episode_path(episode_id))
128
+
129
+ if self._cache_in_ram:
130
+ self._cache[episode_id] = episode
131
+
132
+ return episode_id
133
+
134
+ def _get_episode_path(self, episode_id: int) -> Path:
135
+ n = 3 # number of hierarchies
136
+ powers = np.arange(n)
137
+ subfolders = np.floor((episode_id % 10 ** (1 + powers)) / 10**powers) * 10**powers
138
+ subfolders = [int(x) for x in subfolders[::-1]]
139
+ subfolders = "/".join([f"{x:0{n - i}d}" for i, x in enumerate(subfolders)])
140
+ return self._directory / subfolders / f"{episode_id}.pt"
141
+
142
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
143
+ super().load_state_dict(state_dict)
144
+ self._cache.clear()
145
+
146
+ def assert_not_static(self) -> None:
147
+ assert not self.is_static, "Trying to modify a static dataset."
148
+
149
+ def save_to_default_path(self) -> None:
150
+ self._default_path.parent.mkdir(exist_ok=True, parents=True)
151
+ torch.save(self.state_dict(), self._default_path)
152
+
153
+ def load_from_default_path(self) -> None:
155
+ if self._default_path.is_file():
156
+ self.load_state_dict(torch.load(self._default_path, weights_only=False))
157
+
158
+
159
+ class GameHdf5Dataset(StateDictMixin, TorchDataset):
160
+ def __init__(self, directory: Path) -> None:
161
+ super().__init__()
162
+ filenames = sorted(Path(directory).rglob("*.hdf5"), key=lambda x: int(x.stem.split("_")[-1]))
163
+ self._filenames = {f"{x.parent.name}/{x.name}": x for x in filenames}
164
+
165
+ self._length_one_episode = self._episode_lengths(self._filenames)
166
+
167
+ self.num_episodes = len(self._filenames)
168
+
169
+ self.num_steps = sum(list(self._length_one_episode.values()))
170
+ self.lengths = np.array(list(self._length_one_episode.values()), dtype=np.int64)
171
+
172
+ def _episode_lengths(self, filenames):
173
+ length_one_episode = {}
174
+
175
+ for filename in filenames:
176
+ with h5py.File(filenames[filename], "r") as f:
177
+ keys = f.keys()
178
+ max_frame_index = max(int(key[len('frame_'):-len('_x')]) for key in keys if key.endswith('_x') and key.startswith('frame_'))
179
+ length_one_episode[filename] = max_frame_index + 1
180
+
181
+ return length_one_episode
182
+
183
+ def __len__(self) -> int:
184
+ return self.num_steps
185
+
186
+ def save_to_default_path(self) -> None:
187
+ pass
188
+
189
+ def __getitem__(self, segment_id: SegmentId) -> Segment:
190
+ episode_length = self._length_one_episode[segment_id.episode_id]
191
+ assert segment_id.start < episode_length and segment_id.stop > 0 and segment_id.start < segment_id.stop
192
+
193
+ pad_len_right = max(0, segment_id.stop - episode_length)
194
+ pad_len_left = max(0, -segment_id.start)
195
+
196
+ start = max(0, segment_id.start)
197
+ stop = min(episode_length, segment_id.stop)
198
+ mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool()
199
+
201
+ with h5py.File(self._filenames[segment_id.episode_id], "r") as f:
202
+ obs = torch.stack([torch.tensor(f[f"frame_{i}_x"][:]).flip(2).permute(2, 0, 1).div(255).mul(2).sub(1) for i in range(start, stop)])
203
+ act = torch.tensor(np.array([f[f"frame_{i}_y"][:] for i in range(start, stop)]))
204
+
205
+ def pad(x):
206
+ right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x
207
+ return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right
208
+
209
+ obs = pad(obs)
210
+ act = pad(act)
211
+ rew = torch.zeros(obs.size(0))
212
+ end = torch.zeros(obs.size(0), dtype=torch.uint8)
213
+ trunc = torch.zeros(obs.size(0), dtype=torch.uint8)
214
+ return Segment(obs, act, rew, end, trunc, mask_padding, info={}, id=SegmentId(segment_id.episode_id, start, stop, segment_id.is_first_batch))
215
+
216
+ def load_episode(self, episode_id: int) -> Episode: # used by DatasetTraverser
217
+ episode_length = self._length_one_episode[episode_id]
218
+ s = self[SegmentId(episode_id, 0, episode_length, None)]
219
+ return Episode(s.obs, s.act, s.rew, s.end, s.trunc, s.info)
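The on-disk layout from _get_episode_path splits the episode id into three zero-padded decimal levels (most significant outermost). A standalone mirror of the computation for illustration:

import numpy as np

def episode_path(episode_id: int) -> str:
    # Same arithmetic as Dataset._get_episode_path, relative to the dataset directory.
    n = 3
    powers = np.arange(n)
    subfolders = np.floor((episode_id % 10 ** (1 + powers)) / 10**powers) * 10**powers
    subfolders = [int(x) for x in subfolders[::-1]]
    return "/".join(f"{x:0{n - i}d}" for i, x in enumerate(subfolders)) + f"/{episode_id}.pt"

print(episode_path(1234))  # 200/30/4/1234.pt
print(episode_path(7))     # 000/00/7/7.pt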
src/data/episode.py ADDED
@@ -0,0 +1,62 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Any, Dict, Optional
5
+
6
+ import torch
7
+
8
+
9
+ @dataclass
10
+ class Episode:
11
+ obs: torch.FloatTensor
12
+ act: torch.LongTensor
13
+ rew: torch.FloatTensor
14
+ end: torch.ByteTensor
15
+ trunc: torch.ByteTensor
16
+ info: Dict[str, Any]
17
+
18
+ def __len__(self) -> int:
19
+ return self.obs.size(0)
20
+
21
+ def __add__(self, other: Episode) -> Episode:
22
+ assert self.dead.sum() == 0
23
+ d = {k: torch.cat((v, other.__dict__[k]), dim=0) for k, v in self.__dict__.items() if k != "info"}
24
+ return Episode(**d, info=merge_info(self.info, other.info))
25
+
26
+ def to(self, device) -> Episode:
27
+ return Episode(**{k: v.to(device) if k != "info" else v for k, v in self.__dict__.items()})
28
+
29
+ @property
30
+ def dead(self) -> torch.ByteTensor:
31
+ return (self.end + self.trunc).clip(max=1)
32
+
33
+ def compute_metrics(self) -> Dict[str, Any]:
34
+ return {"length": len(self), "return": self.rew.sum().item()}
35
+
36
+ @classmethod
37
+ def load(cls, path: Path, map_location: Optional[torch.device] = None) -> Episode:
38
+ return cls(
39
+ **{
40
+ k: v.div(255).mul(2).sub(1) if k == "obs" else v
41
+ for k, v in torch.load(Path(path), map_location=map_location).items()
42
+ }
43
+ )
44
+
45
+ def save(self, path: Path) -> None:
46
+ path = Path(path)
47
+ path.parent.mkdir(parents=True, exist_ok=True)
48
+ d = {k: v.add(1).div(2).mul(255).byte() if k == "obs" else v for k, v in self.__dict__.items()}
49
+ torch.save(d, path.with_suffix(".tmp"))
50
+ path.with_suffix(".tmp").rename(path)
51
+
52
+
53
+ def merge_info(info_a, info_b):
54
+ keys_a = set(info_a)
55
+ keys_b = set(info_b)
56
+ intersection = keys_a & keys_b
57
+ info = {
58
+ **{k: info_a[k] for k in keys_a if k not in intersection},
59
+ **{k: info_b[k] for k in keys_b if k not in intersection},
60
+ **{k: torch.cat((info_a[k], info_b[k]), dim=0) for k in intersection},
61
+ }
62
+ return info
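Note that save() quantizes obs from [-1, 1] floats to uint8, and load() reverses the mapping; a value-level sketch of the round trip (no file I/O):

import torch

obs = torch.rand(4, 3, 8, 8) * 2 - 1        # floats in [-1, 1]
stored = obs.add(1).div(2).mul(255).byte()  # what save() writes to disk
restored = stored.div(255).mul(2).sub(1)    # what load() hands back
print((obs - restored).abs().max())         # quantization error below 2/255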
src/data/segment.py ADDED
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Union
4
+
5
+ import torch
6
+
7
+
8
+ @dataclass
9
+ class SegmentId:
10
+ episode_id: Union[int, str]
11
+ start: int
12
+ stop: int
13
+ is_first_batch: bool
14
+
15
+
16
+ @dataclass
17
+ class Segment:
18
+ obs: torch.FloatTensor
19
+ act: torch.LongTensor
20
+ rew: torch.FloatTensor
21
+ end: torch.ByteTensor
22
+ trunc: torch.ByteTensor
23
+ mask_padding: torch.BoolTensor
24
+ info: Dict[str, Any]
25
+ id: SegmentId
26
+
27
+ @property
28
+ def effective_size(self):
29
+ return self.mask_padding.sum().item()
src/data/utils.py ADDED
@@ -0,0 +1,90 @@
1
+ import math
2
+ from typing import Generator, List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from .batch import Batch
8
+ from .episode import Episode
9
+ from .segment import Segment, SegmentId
10
+
11
+
12
+ def collate_segments_to_batch(segments: List[Segment]) -> Batch:
13
+ attrs = ("obs", "act", "rew", "end", "trunc", "mask_padding")
14
+ stack = (torch.stack([getattr(s, x) for s in segments]) for x in attrs)
15
+ return Batch(*stack, [s.info for s in segments], [s.id for s in segments])
16
+
17
+
18
+ def make_segment(episode: Episode, segment_id: SegmentId, should_pad: bool = True) -> Segment:
19
+ if not (segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop):
20
+ print(f'Failed assertion because: start={segment_id.start}, stop={segment_id.stop}, len(episode)={len(episode)}')
21
+
22
+ assert segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop
23
+ pad_len_right = max(0, segment_id.stop - len(episode))
24
+ pad_len_left = max(0, -segment_id.start)
25
+ assert pad_len_right == pad_len_left == 0 or should_pad
26
+
27
+ def pad(x):
28
+ right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x
29
+ return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right
30
+
31
+ start = max(0, segment_id.start)
32
+ stop = min(len(episode), segment_id.stop)
33
+ mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool()
34
+
35
+ return Segment(
36
+ pad(episode.obs[start:stop]),
37
+ pad(episode.act[start:stop]),
38
+ pad(episode.rew[start:stop]),
39
+ pad(episode.end[start:stop]),
40
+ pad(episode.trunc[start:stop]),
41
+ mask_padding,
42
+ info=episode.info,
43
+ id=SegmentId(segment_id.episode_id, start, stop, segment_id.is_first_batch),
44
+ )
45
+
46
+
47
+ class DatasetTraverser:
48
+ def __init__(self, dataset, batch_num_samples: int, chunk_size: int) -> None:
49
+ self.dataset = dataset
50
+ self.batch_num_samples = batch_num_samples
51
+ self.chunk_size = chunk_size
52
+
53
+ def __len__(self):
54
+ return math.ceil(
55
+ sum(
56
+ [
57
+ math.ceil(self.dataset.lengths[episode_id] / self.chunk_size)
58
+ - int(self.dataset.lengths[episode_id] % self.chunk_size == 1)
59
+ for episode_id in range(self.dataset.num_episodes)
60
+ ]
61
+ )
62
+ / self.batch_num_samples
63
+ )
64
+
65
+ def __iter__(self) -> Generator[Batch, None, None]:
66
+ chunks = []
67
+ for episode_id in range(self.dataset.num_episodes):
68
+ episode = self.dataset.load_episode(episode_id)
69
+ segments = []
70
+ for i in range(math.ceil(len(episode) / self.chunk_size)):
71
+ start = i * self.chunk_size
72
+ stop = (i + 1) * self.chunk_size
73
+ segment = make_segment(
74
+ episode,
75
+ SegmentId(episode_id, start, stop, None),
76
+ should_pad=True,
77
+ )
78
+ segment_id_full_res = SegmentId(episode.info["original_file_id"], start, stop)
79
+ segment.info["full_res"] = self.dataset._dataset_full_res[segment_id_full_res].obs
80
+ chunks.append(segment)
81
+ if chunks[-1].effective_size < 2:
82
+ chunks.pop()
83
+
84
+ while len(chunks) >= self.batch_num_samples:
85
+ yield collate_segments_to_batch(chunks[: self.batch_num_samples])
86
+ chunks = chunks[self.batch_num_samples :]
87
+
88
+ if len(chunks) > 0:
89
+ yield collate_segments_to_batch(chunks)
90
+