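# NOTE: This file was captured from a Hugging Face Spaces page whose status
# read "Runtime error"; the page chrome that preceded the code is reduced to
# this comment.
"""
Demo server for Self-Forcing: streams generated video frames to the browser
over Socket.IO as they are decoded, and saves the frames (plus an MP4) to disk.
"""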
import argparse
import base64
import hashlib
import os
import queue
import random
import re
import subprocess
import time
import traceback
import urllib.request
from io import BytesIO
from threading import Event, Thread

import numpy as np
import torch
from flask import Flask, jsonify, render_template
from flask_socketio import SocketIO, emit
from omegaconf import OmegaConf
from PIL import Image

from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
from demo_utils.utils import generate_timestamp
from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=5001)
parser.add_argument('--host', type=str, default='0.0.0.0')
parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt')
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml')
parser.add_argument('--trt', action='store_true')
args = parser.parse_args()

# Query free VRAM once so the printed value and the low-memory decision agree.
free_vram_gb = get_cuda_free_memory_gb(gpu)
print(f'Free VRAM {free_vram_gb} GB')
# Below 40 GB free, the text encoder is not moved to the GPU wholesale.
low_memory = free_vram_gb < 40

# Load models: the experiment config is merged on top of the default config.
config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)
text_encoder = WanTextEncoder()

# Global variables for dynamic model switching (updated by the generation thread).
current_vae_decoder = None
current_use_taehv = False
fp8_applied = False
torch_compile_applied = False

# Per-animation output state used when saving frames and building the MP4.
frame_number = 0   # running index of frames saved for the current animation
anim_name = ""     # directory / file-name prefix for the current animation
frame_rate = 6     # fps passed to ffmpeg when building the MP4
def initialize_vae_decoder(use_taehv=False, use_trt=False):
    """Build the VAE decoder selected by the flags and make it the current decoder.

    Args:
        use_taehv: use the TAEHV decoder (downloading ``checkpoints/taew2_1.pth``
            on first use) instead of the default Wan VAE decoder.
        use_trt: use the TensorRT wrapper ``VAETRTWrapper``. Takes precedence
            over ``use_taehv`` and returns early, leaving ``current_use_taehv``
            unchanged and skipping the eval/dtype/device setup below.

    Returns:
        The new decoder, which is also stored in the global ``current_vae_decoder``.

    Raises:
        Whatever the TAEHV checkpoint download raises, after logging it.
    """
    global current_vae_decoder, current_use_taehv

    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        current_vae_decoder = VAETRTWrapper()
        return current_vae_decoder

    if use_taehv:
        from demo_utils.taehv import TAEHV

        # Check if taew2_1.pth exists in checkpoints folder, download if missing
        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
        if not os.path.exists(taehv_checkpoint_path):
            print(f"taew2_1.pth not found in checkpoints folder {taehv_checkpoint_path}. Downloading...")
            os.makedirs("checkpoints", exist_ok=True)
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            # Download under a temporary name and rename only on success, so an
            # interrupted download never leaves a truncated checkpoint that the
            # exists() check above would accept on every later call.
            partial_path = taehv_checkpoint_path + ".part"
            try:
                urllib.request.urlretrieve(download_url, partial_path)
                os.replace(partial_path, taehv_checkpoint_path)
                print(f"Successfully downloaded taew2_1.pth to {taehv_checkpoint_path}")
            except Exception as e:
                print(f"Failed to download taew2_1.pth: {e}")
                # Best-effort cleanup of the partial file; the original error is re-raised.
                try:
                    os.remove(partial_path)
                except OSError:
                    pass
                raise

        class DotDict(dict):
            # Attribute-style access, e.g. config.scaling_factor.
            __getattr__ = dict.__getitem__
            __setattr__ = dict.__setitem__

        class TAEHVDiffusersWrapper(torch.nn.Module):
            """Expose TAEHV through a ``decode(latents)`` method plus a ``config`` with ``scaling_factor``."""

            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)

            def decode(self, latents, return_dict=None):
                # Output is rescaled as x -> 2x - 1 (presumably [0, 1] -> [-1, 1]; TODO confirm).
                # low-memory, set parallel=True for faster + higher memory
                return self.taehv.decode_video(latents, parallel=False).mul_(2).sub_(1)

        current_vae_decoder = TAEHVDiffusersWrapper()
    else:
        current_vae_decoder = VAEDecoderWrapper()
        vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
        # Keep only the decoder-side weights of the full VAE checkpoint.
        decoder_state_dict = {
            key: value
            for key, value in vae_state_dict.items()
            if 'decoder.' in key or 'conv2' in key
        }
        current_vae_decoder.load_state_dict(decoder_state_dict)

    # Inference-only setup shared by both decoders: fp16, frozen, on the GPU.
    current_vae_decoder.eval()
    current_vae_decoder.to(dtype=torch.float16)
    current_vae_decoder.requires_grad_(False)
    current_vae_decoder.to(gpu)
    current_use_taehv = use_taehv
    print(f"✅ VAE decoder initialized with {'TAEHV' if use_taehv else 'default VAE'}")
    return current_vae_decoder
# Initialize with default VAE (or the TensorRT decoder when --trt is given).
# The TRT decode path in generate_video_stream calls this module-level
# `vae_decoder` directly.
vae_decoder = initialize_vae_decoder(use_taehv=False, use_trt=args.trt)

# Causal diffusion transformer; weights come from the checkpoint's
# 'generator_ema' entry.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
transformer = WanDiffusionWrapper(is_causal=True)
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
transformer.load_state_dict(state_dict['generator_ema'])

# Inference only: eval mode and frozen weights; transformer in fp16,
# text encoder in bf16.
text_encoder.eval()
transformer.eval()
transformer.to(dtype=torch.float16)
text_encoder.to(dtype=torch.bfloat16)
text_encoder.requires_grad_(False)
transformer.requires_grad_(False)

pipeline = CausalInferencePipeline(
    config,
    device=gpu,
    generator=transformer,
    text_encoder=text_encoder,
    vae=vae_decoder
)

# In low-memory mode the text encoder is not moved to the GPU wholesale;
# DynamicSwapInstaller presumably places it on the GPU on demand - TODO confirm.
if low_memory:
    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
else:
    text_encoder.to(gpu)

transformer.to(gpu)
# Flask and SocketIO setup
app = Flask(__name__)
# The signing key can be supplied via the environment; the previous hard-coded
# value remains the fallback so existing deployments behave the same.
app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY', 'frontend_buffered_demo')
socketio = SocketIO(app, cors_allowed_origins="*")

# State shared between the Socket.IO handlers and the worker threads.
generation_active = False         # True while a job runs; also guards against double starts
stop_event = Event()              # set to ask the running job to stop
frame_send_queue = queue.Queue()  # (frame_tensor, frame_index, block_index, job_id) tuples; None stops the sender
sender_thread = None              # thread running frame_sender_worker
models_compiled = False           # whether torch.compile has already been triggered
def tensor_to_base64_frame(frame_tensor):
    """Encode one frame as a base64 JPEG data URL and save a copy to disk.

    Values are clamped to [-1, 1] and mapped to 0-255. A 3-D tensor is treated
    as CHW and transposed to HWC; single-channel and 2-D frames are saved as
    grayscale. Each call increments the global ``frame_number`` and writes
    ``./images/<anim_name>/<anim_name>_<frame_number:03d>.jpg``, the pattern
    that ``generate_mp4_from_images`` reads back.

    Returns:
        A ``data:image/jpeg;base64,...`` string.
    """
    global frame_number, anim_name

    # Clamp and normalize to 0-255
    frame = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
    frame = frame.to(torch.uint8).cpu().numpy()

    # CHW -> HWC
    if frame.ndim == 3:
        frame = np.transpose(frame, (1, 2, 0))
        # PIL cannot build an image from an HxWx1 array; drop the channel axis.
        if frame.shape[2] == 1:
            frame = frame[:, :, 0]

    # Convert to PIL Image
    if frame.ndim == 3 and frame.shape[2] == 3:  # RGB
        image = Image.fromarray(frame, 'RGB')
    else:  # grayscale or other layouts PIL can infer
        image = Image.fromarray(frame)

    # Convert to base64
    buffer = BytesIO()
    image.save(buffer, format='JPEG', quality=100)

    # exist_ok avoids the race between a separate exists() check and makedirs().
    frame_dir = "./images/%s" % anim_name
    os.makedirs(frame_dir, exist_ok=True)
    frame_number += 1
    image.save(os.path.join(frame_dir, "%s_%03d.jpg" % (anim_name, frame_number)))

    img_str = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"
def frame_sender_worker():
    """Drain ``frame_send_queue`` and emit each frame to clients as ``frame_ready``.

    Each queue item is ``(frame_tensor, frame_index, block_index, job_id)``.
    The loop exits on a ``None`` sentinel, or when the queue has been empty for
    1 s while no generation is active. Every item taken from the queue is
    marked ``task_done()`` exactly once, so ``frame_send_queue.join()`` in the
    generation thread can return.
    """
    print("📡 Frame sender thread started")

    while True:
        try:
            # Get frame data from queue
            frame_data = frame_send_queue.get(timeout=1.0)
        except queue.Empty:
            # Idle: stop once no job is running and nothing is left to send.
            if not generation_active and frame_send_queue.empty():
                break
            continue

        if frame_data is None:  # Shutdown signal
            frame_send_queue.task_done()
            break

        try:
            frame_tensor, frame_index, block_index, job_id = frame_data
            # Convert tensor to base64
            base64_frame = tensor_to_base64_frame(frame_tensor)
            # Send via SocketIO
            try:
                socketio.emit('frame_ready', {
                    'data': base64_frame,
                    'frame_index': frame_index,
                    'block_index': block_index,
                    'job_id': job_id
                })
            except Exception as e:
                print(f"⚠️ Failed to send frame {frame_index}: {e}")
        except Exception as e:
            # Log and keep serving. Exiting here would leave the rest of the
            # queued frames unacknowledged, and the generation thread's
            # frame_send_queue.join() would then block forever.
            print(f"❌ Frame sender error: {e}")
        finally:
            frame_send_queue.task_done()

    print("📡 Frame sender thread stopped")
def _start_frame_sender_if_needed():
    """Start a daemon thread running ``frame_sender_worker`` unless one is alive."""
    global sender_thread
    if sender_thread is None or not sender_thread.is_alive():
        sender_thread = Thread(target=frame_sender_worker, daemon=True)
        sender_thread.start()


def _reset_frame_sender():
    """Retire any previous sender thread and discard stale ``frame_send_queue`` entries.

    Earlier jobs can leave ``None`` shutdown sentinels in the queue: a stop
    request may enqueue one, and a job's ``finally`` enqueues another even when
    the sender has already exited on its idle timeout. A new sender would read
    such a leftover sentinel and exit at once. The new job's frames would then
    never be sent, and ``frame_send_queue.join()`` would block forever.
    """
    global sender_thread
    if sender_thread is not None and sender_thread.is_alive():
        # The old sender exits on the first sentinel it reads.
        frame_send_queue.put(None)
        sender_thread.join()
    while True:
        try:
            frame_send_queue.get_nowait()
        except queue.Empty:
            break
        # Keep the queue's unfinished-task count consistent for join().
        frame_send_queue.task_done()
    sender_thread = None


def generate_video_stream(prompt, seed, enable_torch_compile=False, enable_fp8=False, use_taehv=False):
    """Generate a video and push each decoded frame to the frontend immediately.

    Steps:
        1. Start a fresh frame sender thread.
        2. Optionally switch the VAE decoder, apply FP8 quantization, and
           request torch.compile.
        3. Encode ``prompt``.
        4. Denoise ``num_blocks`` blocks of latent frames with the causal
           transformer, reusing its KV cache across blocks.
        5. Decode each block to pixels and queue its frames on
           ``frame_send_queue``.
        6. Wait until every queued frame has been sent, build
           ``./videos/<anim_name>.mp4``, and emit ``generation_complete``.

    Progress is reported through ``progress`` events and failures through an
    ``error`` event. Generation stops early when ``generation_active`` is
    cleared or ``stop_event`` is set.

    Args:
        prompt: text prompt.
        seed: seed for the initial-noise generator.
        enable_torch_compile: compile the transformer (and, when applicable, the
            default VAE decoder) the first time it is requested.
        enable_fp8: apply torchao FP8 dynamic quantization to the transformer,
            at most once per process.
        use_taehv: decode with TAEHV instead of the default VAE. Ignored with
            --trt, where the TensorRT decoder is always used.
    """
    global generation_active, stop_event, frame_send_queue, sender_thread, models_compiled, torch_compile_applied, fp8_applied, current_vae_decoder, current_use_taehv, frame_rate, anim_name

    # Bound before the try so the error handler can always reference it.
    job_id = None
    try:
        generation_active = True
        stop_event.clear()
        job_id = generate_timestamp()

        # Start a clean frame sender thread for this job
        _reset_frame_sender()
        _start_frame_sender_if_needed()

        # Emit progress updates
        def emit_progress(message, progress):
            try:
                socketio.emit('progress', {
                    'message': message,
                    'progress': progress,
                    'job_id': job_id
                })
            except Exception as e:
                print(f"❌ Failed to emit progress: {e}")

        emit_progress('Starting generation...', 0)

        # Handle VAE decoder switching. With --trt, decoding always goes through
        # the TensorRT decoder (module-level `vae_decoder`). Switching to TAEHV
        # would set vae_cache to None below and crash the TRT decode call.
        if use_taehv != current_use_taehv:
            if args.trt:
                print("⚠️ Ignoring VAE decoder switch: TensorRT VAE decoder is in use")
            else:
                emit_progress('Switching VAE decoder...', 2)
                print(f"🔄 Switching VAE decoder to {'TAEHV' if use_taehv else 'default VAE'}")
                current_vae_decoder = initialize_vae_decoder(use_taehv=use_taehv)
                # Update pipeline with new VAE decoder
                pipeline.vae = current_vae_decoder

        # Handle FP8 quantization (applied at most once per process)
        if enable_fp8 and not fp8_applied:
            emit_progress('Applying FP8 quantization...', 3)
            print("🔧 Applying FP8 quantization to transformer")
            from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
            quantize_(transformer, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
            fp8_applied = True

        # Text encoding
        emit_progress('Encoding text prompt...', 8)
        conditional_dict = text_encoder(text_prompts=[prompt])
        for key, value in conditional_dict.items():
            conditional_dict[key] = value.to(dtype=torch.float16)
        if low_memory:
            gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
            move_model_to_device_with_memory_preservation(
                text_encoder, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)

        # Handle torch.compile if enabled
        torch_compile_applied = enable_torch_compile
        if enable_torch_compile and not models_compiled:
            # Compile transformer and decoder
            transformer.compile(mode="max-autotune-no-cudagraphs")
            if not current_use_taehv and not low_memory and not args.trt:
                current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")

        # Initialize generation
        emit_progress('Initializing generation...', 12)
        rnd = torch.Generator(gpu).manual_seed(seed)
        pipeline._initialize_kv_cache(batch_size=1, dtype=torch.float16, device=gpu)
        pipeline._initialize_crossattn_cache(batch_size=1, dtype=torch.float16, device=gpu)
        noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)

        # Generation parameters
        num_blocks = 7
        current_start_frame = 0
        num_input_frames = 0
        all_num_frames = [pipeline.num_frame_per_block] * num_blocks
        if current_use_taehv:
            vae_cache = None
        else:
            # Per-job copy on the GPU; leaves the shared ZERO_VAE_CACHE untouched.
            vae_cache = [cache.to(device=gpu, dtype=torch.float16) for cache in ZERO_VAE_CACHE]

        total_frames_sent = 0
        generation_start_time = time.time()
        emit_progress('Generating frames... (frontend handles timing)', 15)

        for idx, current_num_frames in enumerate(all_num_frames):
            if not generation_active or stop_event.is_set():
                break

            progress = int(((idx + 1) / len(all_num_frames)) * 80) + 15

            # Special message for first block with torch.compile
            if idx == 0 and torch_compile_applied and not models_compiled:
                emit_progress(
                    f'Processing block 1/{len(all_num_frames)} - Compiling models (may take 5-10 minutes)...', progress)
                print(f"🔥 Processing block {idx+1}/{len(all_num_frames)}")
                models_compiled = True
            else:
                emit_progress(f'Processing block {idx+1}/{len(all_num_frames)}...', progress)
                print(f"🔄 Processing block {idx+1}/{len(all_num_frames)}")

            block_start_time = time.time()
            noisy_input = noise[:, current_start_frame -
                                num_input_frames:current_start_frame + current_num_frames - num_input_frames]

            # Denoising loop: every step except the last re-noises the
            # prediction to the next timestep.
            denoising_start = time.time()
            for index, current_timestep in enumerate(pipeline.denoising_step_list):
                if not generation_active or stop_event.is_set():
                    break

                timestep = torch.ones([1, current_num_frames], device=noise.device,
                                      dtype=torch.int64) * current_timestep

                if index < len(pipeline.denoising_step_list) - 1:
                    _, denoised_pred = transformer(
                        noisy_image_or_video=noisy_input,
                        conditional_dict=conditional_dict,
                        timestep=timestep,
                        kv_cache=pipeline.kv_cache1,
                        crossattn_cache=pipeline.crossattn_cache,
                        current_start=current_start_frame * pipeline.frame_seq_length
                    )
                    next_timestep = pipeline.denoising_step_list[index + 1]
                    noisy_input = pipeline.scheduler.add_noise(
                        denoised_pred.flatten(0, 1),
                        torch.randn_like(denoised_pred.flatten(0, 1)),
                        next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
                    ).unflatten(0, denoised_pred.shape[:2])
                else:
                    _, denoised_pred = transformer(
                        noisy_image_or_video=noisy_input,
                        conditional_dict=conditional_dict,
                        timestep=timestep,
                        kv_cache=pipeline.kv_cache1,
                        crossattn_cache=pipeline.crossattn_cache,
                        current_start=current_start_frame * pipeline.frame_seq_length
                    )

            if not generation_active or stop_event.is_set():
                break

            denoising_time = time.time() - denoising_start
            print(f"⚡ Block {idx+1} denoising completed in {denoising_time:.2f}s")

            # Update KV cache for next block (timestep 0 on the clean prediction)
            if idx != len(all_num_frames) - 1:
                transformer(
                    noisy_image_or_video=denoised_pred,
                    conditional_dict=conditional_dict,
                    timestep=torch.zeros_like(timestep),
                    kv_cache=pipeline.kv_cache1,
                    crossattn_cache=pipeline.crossattn_cache,
                    current_start=current_start_frame * pipeline.frame_seq_length,
                )

            # Decode to pixels and send frames immediately
            print(f"🎨 Decoding block {idx+1} to pixels...")
            decode_start = time.time()
            if args.trt:
                # TensorRT decoder: one latent frame at a time, threading the cache through.
                all_current_pixels = []
                for i in range(denoised_pred.shape[1]):
                    is_first_frame = torch.tensor(1.0).cuda().half() if idx == 0 and i == 0 else \
                        torch.tensor(0.0).cuda().half()
                    outputs = vae_decoder.forward(denoised_pred[:, i:i + 1, :, :, :].half(), is_first_frame, *vae_cache)
                    current_pixels, vae_cache = outputs[0], outputs[1:]
                    all_current_pixels.append(current_pixels.clone())
                pixels = torch.cat(all_current_pixels, dim=1)
                if idx == 0:
                    pixels = pixels[:, 3:, :, :, :]  # Skip first 3 frames of first block
            else:
                if current_use_taehv:
                    # TAEHV: prepend the previous block's last 3 latents as context,
                    # then drop the pixels that correspond to that context.
                    if vae_cache is None:
                        vae_cache = denoised_pred
                    else:
                        denoised_pred = torch.cat([vae_cache, denoised_pred], dim=1)
                        vae_cache = denoised_pred[:, -3:, :, :, :]
                    pixels = current_vae_decoder.decode(denoised_pred)
                    print(f"denoised_pred shape: {denoised_pred.shape}")
                    print(f"pixels shape: {pixels.shape}")
                    if idx == 0:
                        pixels = pixels[:, 3:, :, :, :]  # Skip first 3 frames of first block
                    else:
                        pixels = pixels[:, 12:, :, :, :]
                else:
                    pixels, vae_cache = current_vae_decoder(denoised_pred.half(), *vae_cache)
                    if idx == 0:
                        pixels = pixels[:, 3:, :, :, :]  # Skip first 3 frames of first block

            decode_time = time.time() - decode_start
            print(f"🎨 Block {idx+1} VAE decoding completed in {decode_time:.2f}s")

            # Queue frames for non-blocking sending
            block_frames = pixels.shape[1]
            print(f"📡 Queueing {block_frames} frames from block {idx+1} for sending...")
            queue_start = time.time()
            for frame_idx in range(block_frames):
                if not generation_active or stop_event.is_set():
                    break
                frame_tensor = pixels[0, frame_idx].cpu()
                frame_send_queue.put((frame_tensor, total_frames_sent, idx, job_id))
                total_frames_sent += 1

            queue_time = time.time() - queue_start
            block_time = time.time() - block_start_time
            print(f"✅ Block {idx+1} completed in {block_time:.2f}s ({block_frames} frames queued in {queue_time:.3f}s)")

            current_start_frame += current_num_frames

        generation_time = time.time() - generation_start_time
        print(f"🎉 Generation completed in {generation_time:.2f}s! {total_frames_sent} frames queued for sending")

        # Wait for all frames to be sent before completing
        emit_progress('Waiting for all frames to be sent...', 97)
        print("⏳ Waiting for all frames to be sent...")
        # A stop request can make the sender exit before the queue is drained;
        # make sure one is running so join() can complete.
        _start_frame_sender_if_needed()
        frame_send_queue.join()
        print("✅ All frames sent successfully!")
        generate_mp4_from_images("./images", "./videos/" + anim_name + ".mp4", frame_rate)

        # Final progress update
        emit_progress('Generation complete!', 100)
        try:
            socketio.emit('generation_complete', {
                'message': 'Video generation completed!',
                'total_frames': total_frames_sent,
                'generation_time': f"{generation_time:.2f}s",
                'job_id': job_id
            })
        except Exception as e:
            print(f"❌ Failed to emit generation complete: {e}")

    except Exception as e:
        print(f"❌ Generation failed: {e}")
        traceback.print_exc()
        try:
            socketio.emit('error', {
                'message': f'Generation failed: {str(e)}',
                'job_id': job_id
            })
        except Exception as emit_error:
            print(f"❌ Failed to emit error: {emit_error}")
    finally:
        generation_active = False
        stop_event.set()
        # Clean up sender thread
        try:
            frame_send_queue.put(None)
        except Exception as e:
            print(f"❌ Failed to put None in frame_send_queue: {e}")
def generate_mp4_from_images(image_directory, output_video_path, fps=24):
    """
    Encode the current animation's saved frames into an H.264 MP4 with ffmpeg.

    Frames are read as the numbered sequence
    ``<image_directory>/<anim_name>/<anim_name>_%03d.jpg`` (the names
    ``tensor_to_base64_frame`` writes), using the global ``anim_name``.

    :param image_directory: Root directory holding the per-animation frame folders.
    :param output_video_path: Path where the output MP4 will be saved. Its parent
        directory is created if missing; an existing file is overwritten.
    :param fps: Frames per second for the output video.

    Failures (ffmpeg error exit, ffmpeg not installed, directory creation) are
    logged rather than raised.
    """
    global anim_name
    cmd = [
        'ffmpeg',
        # Overwrite without asking: otherwise ffmpeg prompts on stdin when the
        # same anim_name (same prompt and seed) is rendered again.
        '-y',
        '-framerate', str(fps),
        '-i', os.path.join(image_directory, anim_name, anim_name + '_%03d.jpg'),
        '-c:v', 'libx264',
        '-pix_fmt', 'yuv420p',
        output_video_path,
    ]
    try:
        output_dir = os.path.dirname(output_video_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        subprocess.run(cmd, check=True)
        print(f"Video saved to {output_video_path}")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing ffmpeg binary and directory-creation failures.
        print(f"An error occurred: {e}")
def calculate_sha256(data):
    """Return the hex SHA-256 digest of *data*; ``str`` input is encoded first."""
    payload = data.encode() if isinstance(data, str) else data
    return hashlib.sha256(payload).hexdigest()
# Socket.IO event handlers
@socketio.on('connect')
def handle_connect(auth=None):
    """Greet a newly connected client with a status message.

    ``auth`` is the optional connection payload Flask-SocketIO may pass; it is unused.
    """
    print('Client connected')
    emit('status', {'message': 'Connected to frontend-buffered demo server'})


@socketio.on('disconnect')
def handle_disconnect(reason=None):
    """Log a client disconnect. ``reason`` (passed by newer Flask-SocketIO versions) is unused."""
    print('Client disconnected')
def _prompt_name_prefix(prompt):
    """Return a filesystem-safe name prefix taken from the start of *prompt*.

    Uses the text before the first punctuation character. If that is empty
    (the prompt starts with punctuation), falls back to the first line with
    punctuation removed. Whitespace runs collapse to single spaces, so the
    result never contains path separators, dots, or newlines.
    """
    words = re.split(r'[^\w\s]', prompt)[0].strip()
    if not words:
        first_line = re.split(r'[\n\r]', prompt)[0]
        words = re.sub(r'[^\w\s]', '', first_line)
    return " ".join(words.split())


@socketio.on('start_generation')
def handle_start_generation(data):
    """Validate a start request, derive the output name, and launch generation.

    ``data`` keys:
        - ``prompt``: required, non-empty.
        - ``seed``: -1 or missing means random.
        - ``enable_torch_compile``, ``enable_fp8``, ``use_taehv``: booleans,
          default False.
        - ``fps``: default 6.

    Errors are reported to the client with an ``error`` event.
    """
    global generation_active, frame_number, anim_name, frame_rate

    if generation_active:
        # Leave the running job's state (frame_number, anim_name, fps) untouched.
        emit('error', {'message': 'Generation already in progress'})
        return

    data = data or {}
    prompt = data.get('prompt', '')
    # Validate before marking the job active. Returning after that point
    # would leave generation_active stuck at True.
    if not prompt:
        emit('error', {'message': 'Prompt is required'})
        return

    raw_seed = data.get('seed', -1)
    try:
        seed = -1 if raw_seed in (None, '') else int(raw_seed)
    except (TypeError, ValueError):
        emit('error', {'message': 'Seed must be an integer'})
        return
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    # anim_name becomes a directory and file name under ./images and ./videos,
    # so it is built only from word characters, spaces, the seed, and the hash.
    name_prefix = _prompt_name_prefix(prompt)
    sha256_hash = calculate_sha256(prompt)
    # Create anim_name with the extracted words and first 10 characters of the hash
    anim_name = f"{name_prefix[:20]}_{str(seed)}_{sha256_hash[:10]}"
    frame_number = 0

    enable_torch_compile = data.get('enable_torch_compile', False)
    enable_fp8 = data.get('enable_fp8', False)
    use_taehv = data.get('use_taehv', False)
    frame_rate = data.get('fps', 6)

    generation_active = True
    # Start generation in background thread
    socketio.start_background_task(generate_video_stream, prompt, seed,
                                   enable_torch_compile, enable_fp8, use_taehv)
    emit('status', {'message': 'Generation started - frames will be sent immediately'})
@socketio.on('stop_generation')
def handle_stop_generation(data=None):
    """Ask the running generation job to stop.

    Only the flags are set here. The generation thread checks them between
    steps, still waits for already-queued frames to be sent, and enqueues the
    sender's ``None`` shutdown sentinel itself in its ``finally`` block.
    A second sentinel enqueued here could be left in the queue after the job
    ends (or when no job is running at all). The next job's sender would then
    read it and exit immediately, so that job's frames would never be
    delivered. ``data`` is an optional client payload and is ignored.
    """
    global generation_active
    generation_active = False
    stop_event.set()
    emit('status', {'message': 'Generation stopped'})
# Web routes
@app.route('/')
def index():
    """Serve the demo page template ``demo.html``."""
    return render_template('demo.html')
@app.route('/api/status')
def api_status():
    """Return generation state, free VRAM, and active optimizations as JSON."""
    return jsonify({
        'generation_active': generation_active,
        'free_vram_gb': get_cuda_free_memory_gb(gpu),
        'fp8_applied': fp8_applied,
        'torch_compile_applied': torch_compile_applied,
        'current_use_taehv': current_use_taehv
    })
if __name__ == '__main__':
    print(f"🚀 Starting demo on http://{args.host}:{args.port}")
    run_kwargs = {}
    # In the threading async mode Flask-SocketIO serves with Werkzeug. It
    # refuses to start (RuntimeError) when stdin is not a TTY, e.g. in
    # containers or services, unless allow_unsafe_werkzeug=True is given.
    # The option only applies to that mode, so it is not passed under
    # eventlet/gevent.
    if socketio.server.eio.async_mode == 'threading':
        run_kwargs['allow_unsafe_werkzeug'] = True
    socketio.run(app, host=args.host, port=args.port, debug=False, **run_kwargs)
