#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MatrixGame V2 Engine

This module handles the core rendering and model inference for the
Matrix-Game V2 project.
"""
import os
import argparse

import torch
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from torchvision.transforms import v2
from typing import List, Optional
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

# Matrix-Game V2 specific imports
from pipeline import CausalInferenceStreamingPipeline
from wan.vae.wanx_vae import get_wanx_vae_wrapper
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.misc import set_seed
from utils.conditions import *
from utils.wan_wrapper import WanDiffusionWrapper

# Utility functions shared with the API layer
from api_utils import (
    visualize_controls,
    frame_to_jpeg,
    load_scene_frames,
    logger
)


class MatrixGameEngine:
    """Core engine for Matrix-Game V2 model inference and frame generation."""

    def __init__(self, args: Optional[argparse.Namespace] = None):
        """
        Initialize the Matrix-Game V2 engine with configuration parameters.

        Args:
            args: Optional parsed command-line arguments for model configuration
        """
        # Fall back to defaults when args are not provided.
        # V2 uses 352x640 as its standard resolution.
        self.frame_width = getattr(args, 'frame_width', 640)
        self.frame_height = getattr(args, 'frame_height', 352)
        self.fps = getattr(args, 'fps', 16)
        self.max_num_output_frames = getattr(args, 'max_num_output_frames', 90)  # Reduced for real-time use
        self.seed = getattr(args, 'seed', 0)
        self.config_path = getattr(args, 'config_path', 'configs/inference_yaml/inference_universal.yaml')
        self.checkpoint_path = getattr(args, 'checkpoint_path', '')
        self.pretrained_model_path = getattr(args, 'pretrained_model_path', 'Matrix-Game-2.0')

        # Initialize state
        self.frame_count = 0
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.weight_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        # Frame preprocessing pipeline: resize, convert to tensor, normalize to [-1, 1]
        self.frame_process = v2.Compose([
            v2.Resize(size=(self.frame_height, self.frame_width), antialias=True),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        # Cache the initial frames for each scene, including the V2-specific
        # 'universal', 'gta_drive', and 'temple_run' scenes.
        scene_names = [
            'forest', 'desert', 'beach', 'hills', 'river', 'icy', 'mushroom',
            'plain', 'universal', 'gta_drive', 'temple_run',
        ]
        self.scenes = {
            name: load_scene_frames(name, self.frame_width, self.frame_height)
            for name in scene_names
        }

        # Cache for preprocessed images and latents
        self.scene_latents = {}
        self.current_latent = None
        self.current_frame_idx = 0

        # Initialize the Matrix-Game V2 pipeline
        self.model_loaded = False
        if not torch.cuda.is_available():
            error_msg = "CUDA is not available. Matrix-Game V2 requires an NVIDIA GPU with CUDA support."
            logger.error(error_msg)
            raise RuntimeError(error_msg)
        try:
            self._init_models()
            self.model_loaded = True
            logger.info("Matrix-Game V2 models loaded successfully")
        except Exception as e:
            error_msg = f"Failed to initialize Matrix-Game V2 models: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e
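
    # Note on value ranges: `frame_process` maps 8-bit RGB input to [-1, 1]
    # (ToTensor scales to [0, 1], then Normalize applies (x - 0.5) / 0.5), and
    # `generate_frame` inverts this with (frame + 1) * 127.5 when converting
    # decoded output back to uint8.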

    def _init_models(self):
        """Initialize the Matrix-Game V2 models."""
        try:
            # Load configuration
            self.config = OmegaConf.load(self.config_path)

            # Initialize the causal diffusion generator
            generator = WanDiffusionWrapper(
                **getattr(self.config, "model_kwargs", {}), is_causal=True)

            # Initialize the VAE decoder
            current_vae_decoder = VAEDecoderWrapper()

            # If the model does not exist locally, download it from Hugging Face
            if (not os.path.exists(self.pretrained_model_path)
                    or not os.path.exists(os.path.join(self.pretrained_model_path, "Wan2.1_VAE.pth"))):
                logger.info(f"Model not found at {self.pretrained_model_path}, downloading from Hugging Face...")
                try:
                    downloaded_path = snapshot_download(
                        repo_id="Skywork/Matrix-Game-2.0",
                        local_dir=self.pretrained_model_path
                    )
                    logger.info(f"Successfully downloaded model to {downloaded_path}")
                except Exception as e:
                    logger.error(f"Failed to download model from Hugging Face: {str(e)}")
                    raise

            # Load the VAE weights, keeping only the decoder-side keys
            vae_state_dict = torch.load(os.path.join(self.pretrained_model_path, "Wan2.1_VAE.pth"),
                                        map_location="cpu")
            decoder_state_dict = {
                key: value for key, value in vae_state_dict.items()
                if 'decoder.' in key or 'conv2' in key
            }
            current_vae_decoder.load_state_dict(decoder_state_dict)
            current_vae_decoder.to(self.device, torch.float16)
            current_vae_decoder.requires_grad_(False)
            current_vae_decoder.eval()

            # Compile the decoder for lower latency; fall back if compilation fails
            try:
                current_vae_decoder.compile(mode="reduce-overhead")
            except Exception:
                logger.warning("VAE decoder compilation failed, continuing without compilation")

            # Initialize the streaming pipeline for real-time generation
            self.pipeline = CausalInferenceStreamingPipeline(
                self.config, generator=generator, vae_decoder=current_vae_decoder)

            # Load the generator checkpoint if one was provided
            if self.checkpoint_path and os.path.exists(self.checkpoint_path):
                logger.info("Loading checkpoint...")
                state_dict = load_file(self.checkpoint_path)
                self.pipeline.generator.load_state_dict(state_dict)

            self.pipeline = self.pipeline.to(device=self.device, dtype=self.weight_dtype)
            self.pipeline.vae_decoder.to(torch.float16)

            # Initialize the VAE encoder
            vae = get_wanx_vae_wrapper(self.pretrained_model_path, torch.float16)
            vae.requires_grad_(False)
            vae.eval()
            self.vae = vae.to(self.device, self.weight_dtype)

            logger.info("Models loaded successfully")

            # Preprocess the initial image of every scene that has frames,
            # using the first frame as the initial conditioning latent
            for scene_name, frames in self.scenes.items():
                if frames:
                    self._prepare_scene_latent(scene_name, frames[0])
        except Exception as e:
            logger.error(f"Error loading models: {str(e)}")
            raise

    def _resizecrop(self, image, th, tw):
        """Center-crop the image to the target th:tw aspect ratio.

        The actual resize to (th, tw) happens later in `self.frame_process`.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        w, h = image.size
        if h / w > th / tw:
            new_w = int(w)
            new_h = int(new_w * th / tw)
        else:
            new_h = int(h)
            new_w = int(new_h * tw / th)
        left = (w - new_w) / 2
        top = (h - new_h) / 2
        right = (w + new_w) / 2
        bottom = (h + new_h) / 2
        return image.crop((left, top, right, bottom))

    def _prepare_scene_latent(self, scene_name: str, frame: np.ndarray):
        """Prepare and cache the conditioning latent for a scene."""
        try:
            # Convert to PIL if needed
            if isinstance(frame, np.ndarray):
                image = Image.fromarray(frame)
            else:
                image = frame

            # Crop to the target aspect ratio, then resize and normalize
            image = self._resizecrop(image, self.frame_height, self.frame_width)
            processed = self.frame_process(image)[None, :, None, :, :].to(
                dtype=self.weight_dtype, device=self.device)

            # Pad the single conditioning frame with zeros so the clip spans
            # 4 * (max_num_output_frames - 1) + 1 pixel frames before encoding
            padding_video = torch.zeros_like(processed).repeat(
                1, 1, 4 * (self.max_num_output_frames - 1), 1, 1)
            img_cond = torch.concat([processed, padding_video], dim=2)

            # Encode to latent space, tiling for memory efficiency
            tiler_kwargs = {"tiled": True, "tile_size": [44, 80], "tile_stride": [23, 38]}
            img_latent = self.vae.encode(img_cond, device=self.device, **tiler_kwargs).to(self.device)

            # Mask: only the first latent frame carries real conditioning
            mask_cond = torch.ones_like(img_latent)
            mask_cond[:, :, 1:] = 0

            # Store the preprocessed data
            self.scene_latents[scene_name] = {
                'image': processed,
                'latent': img_latent,
                'mask': mask_cond,
                'visual_context': self.vae.clip.encode_video(processed)
            }
        except Exception as e:
            logger.error(f"Error preparing latent for scene {scene_name}: {str(e)}")
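
    # Latent geometry, for reference: assuming the Wan2.1 VAE's usual 16
    # latent channels with 8x spatial and 4x temporal compression, a 352x640
    # frame encodes to a 44x80 latent. That is where the
    # [1, 16, num_frames, 44, 80] noise shape in `generate_frame` and the
    # 4 * (max_num_output_frames - 1) zero-padding above come from
    # (4k + 1 pixel frames encode to k + 1 latent frames).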

    def generate_frame(self, scene_name: str, keyboard_condition: Optional[List] = None,
                       mouse_condition: Optional[List] = None) -> bytes:
        """
        Generate the next frame for the current conditions using the Matrix-Game V2 model.

        Args:
            scene_name: Name of the current scene
            keyboard_condition: Keyboard input state
            mouse_condition: Mouse input state

        Returns:
            bytes: JPEG bytes of the frame
        """
        # Refuse to run without a loaded model or a CUDA device
        if not self.model_loaded:
            error_msg = "Model not loaded. Cannot generate frames."
            logger.error(error_msg)
            raise RuntimeError(error_msg)
        if not torch.cuda.is_available():
            error_msg = "CUDA is no longer available. Cannot generate frames."
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        try:
            # Map the scene name to a generation mode
            mode_map = {
                'universal': 'universal',
                'gta_drive': 'gta_drive',
                'temple_run': 'templerun',
                'templerun': 'templerun'
            }
            mode = mode_map.get(scene_name, 'universal')

            # Get the cached latent, or prepare one on first use
            if scene_name not in self.scene_latents:
                scene_frames = self.scenes.get(scene_name, self.scenes.get('universal', []))
                if scene_frames:
                    self._prepare_scene_latent(scene_name, scene_frames[0])
                else:
                    error_msg = f"No initial frames available for scene: {scene_name}"
                    logger.error(error_msg)
                    raise ValueError(error_msg)

            scene_data = self.scene_latents.get(scene_name)
            if not scene_data:
                error_msg = f"Failed to prepare latent for scene: {scene_name}"
                logger.error(error_msg)
                raise ValueError(error_msg)

            # Default conditions: no keys pressed, no mouse movement
            if keyboard_condition is None:
                keyboard_condition = [[0, 0, 0, 0, 0, 0]]
            if mouse_condition is None:
                mouse_condition = [[0, 0]]

            # Generate several frames per call for smoother streaming playback
            num_frames = 5

            # Create the condition tensors by repeating the control state per frame
            keyboard_tensor = torch.tensor(keyboard_condition * num_frames,
                                           dtype=self.weight_dtype).unsqueeze(0).to(self.device)
            mouse_tensor = torch.tensor(mouse_condition * num_frames,
                                        dtype=self.weight_dtype).unsqueeze(0).to(self.device)
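            # Shape check, derived from the defaults above: repeating the single
            # control row num_frames times gives keyboard_tensor shape [1, 5, 6]
            # (batch, frames, keys) and mouse_tensor shape [1, 5, 2]
            # (batch, frames, x/y), i.e. one control state per generated frame.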

            # Build the conditional dict
            cond_concat = torch.cat([scene_data['mask'][:, :4], scene_data['latent']], dim=1)
            conditional_dict = {
                "cond_concat": cond_concat.to(device=self.device, dtype=self.weight_dtype),
                "visual_context": scene_data['visual_context'].to(device=self.device, dtype=self.weight_dtype),
                "keyboard_cond": keyboard_tensor
            }
            # Only 'universal' and 'gta_drive' modes support mouse conditioning
            if mode in ['universal', 'gta_drive']:
                conditional_dict['mouse_cond'] = mouse_tensor

            # Sample the starting noise for the new latent frames
            sampled_noise = torch.randn(
                [1, 16, num_frames, 44, 80],
                device=self.device, dtype=self.weight_dtype
            )

            # Generate frames with the streaming pipeline
            with torch.no_grad():
                # Offset the seed by the frame count: reproducible, but not static
                set_seed(self.seed + self.frame_count)

                # Single-batch generation via the inference method
                outputs = self.pipeline.inference(
                    noise=sampled_noise,
                    conditional_dict=conditional_dict,
                    return_latents=True,   # Return latents for faster decoding
                    output_folder=None,    # Don't save to disk
                    name=None,
                    mode=mode
                )

                if outputs is not None and len(outputs) > 0:
                    # Decode only the first latent frame
                    frame_latent = outputs[0:1, :, 0:1]
                    decoded = self.pipeline.vae_decoder.decode(frame_latent)

                    # Convert from a [-1, 1] tensor to an HWC uint8 array
                    frame = decoded[0, :, 0].permute(1, 2, 0).cpu().numpy()
                    frame = ((frame + 1) * 127.5).clip(0, 255).astype(np.uint8)
                else:
                    error_msg = "Failed to generate frame: No output from model"
                    logger.error(error_msg)
                    raise RuntimeError(error_msg)

            self.frame_count += 1
        except Exception as e:
            error_msg = f"Error generating frame with Matrix-Game V2 model: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e

        # Overlay a visualization of the input controls
        frame = visualize_controls(
            frame, keyboard_condition, mouse_condition,
            self.frame_width, self.frame_height
        )

        # Encode the frame as JPEG
        return frame_to_jpeg(frame, self.frame_height, self.frame_width)

    def get_valid_scenes(self) -> List[str]:
        """
        Get the list of valid scene names.

        Returns:
            List[str]: List of valid scene names
        """
        return list(self.scenes.keys())
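

# Example usage: a minimal smoke test, not part of the server API. This sketch
# assumes a CUDA-capable host and that the default config path exists; model
# weights are downloaded from Hugging Face on first run if missing.
if __name__ == "__main__":
    engine = MatrixGameEngine()  # defaults: 352x640 @ 16 fps, universal config
    jpeg_bytes = engine.generate_frame(
        scene_name="universal",
        keyboard_condition=[[0, 0, 0, 0, 0, 0]],  # no keys pressed
        mouse_condition=[[0, 0]],                 # no mouse movement
    )
    with open("frame_000.jpg", "wb") as f:
        f.write(jpeg_bytes)
    print(f"Wrote frame_000.jpg; valid scenes: {engine.get_valid_scenes()}")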