Spaces:
Build error
Build error
| import os | |
| import gradio as gr | |
| import torch | |
| import spaces | |
| from PIL import Image | |
| import tempfile | |
| import subprocess | |
| import sys | |
| from huggingface_hub import snapshot_download, hf_hub_download | |
| import shutil | |
# --- Configuration ---------------------------------------------------------
# Hugging Face Hub repository holding the Matrix-Game-2.0 weights.
MODEL_REPO = "Skywork/Matrix-Game-2.0"

# Run on the GPU whenever PyTorch reports CUDA support, else fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Startup banner, one message per line.
print(
    "๐ Matrix-Game-2.0 Streamlined",
    f"๐ฑ Device: {DEVICE}",
    f"๐ฅ CUDA Available: {torch.cuda.is_available()}",
    sep="\n",
)

# Lazily-initialised state written by download_and_setup_model():
# model_loaded becomes True once setup has succeeded, and model_path holds
# the downloaded snapshot directory (None until then).
model_loaded = False
model_path = None
def download_and_setup_model():
    """Download the model weights and prepare the inference code, once.

    On first success this:
      * downloads a filtered snapshot of ``MODEL_REPO`` into ``./model_cache``
        and stores its absolute path in the global ``model_path``;
      * clones https://github.com/SkyworkAI/Matrix-Game into ``./Matrix-Game``
        unless that directory already exists;
      * prepends ``<cwd>/Matrix-Game/Matrix-Game-2`` to ``sys.path``;
      * sets the global ``model_loaded`` so later calls return immediately.

    Returns:
        bool: True if setup succeeded (now or on an earlier call), False
        otherwise. Errors are printed rather than raised so the UI can show
        a short message instead of a traceback.
    """
    global model_loaded, model_path
    if model_loaded:
        return True

    repo_dir = "Matrix-Game"
    matrix_game_path = os.path.join(os.getcwd(), repo_dir, "Matrix-Game-2")
    try:
        print("๐ฅ Downloading Matrix-Game-2.0 model...")
        # Stored as an absolute path so it stays valid when handed to a
        # subprocess that runs with a different working directory.
        model_path = os.path.abspath(
            snapshot_download(
                repo_id=MODEL_REPO,
                cache_dir="./model_cache",
                allow_patterns=["*.safetensors", "*.bin", "*.json", "*.yaml", "*.yml", "*.py"],
            )
        )
        print(f"โ Model downloaded to: {model_path}")

        # Clone the inference code repository
        if not os.path.exists(repo_dir):
            print("๐ฅ Cloning Matrix-Game repository...")
            try:
                result = subprocess.run(
                    ["git", "clone", "https://github.com/SkyworkAI/Matrix-Game.git"],
                    capture_output=True,
                    text=True,
                    timeout=180,
                )
            except subprocess.TimeoutExpired:
                # The killed clone can leave a partial checkout behind; remove
                # it so the exists() check above does not skip re-cloning on
                # the next attempt.
                shutil.rmtree(repo_dir, ignore_errors=True)
                print("โ Git clone failed: timed out after 180 seconds")
                return False
            if result.returncode != 0:
                print(f"โ Git clone failed: {result.stderr}")
                shutil.rmtree(repo_dir, ignore_errors=True)
                return False

        # Do not record setup as done if the checkout lacks the expected
        # sub-project (e.g. a stale or differently laid-out directory).
        if not os.path.isdir(matrix_game_path):
            print(f"โ Setup failed: inference code not found at {matrix_game_path}")
            return False

        # Setup Python path to include Matrix-Game modules
        if matrix_game_path not in sys.path:
            sys.path.insert(0, matrix_game_path)
        model_loaded = True
        return True
    except Exception as e:
        # Top-level boundary for the UI: report and signal failure.
        print(f"โ Setup failed: {e}")
        return False
def _prepare_input_image(image, temp_dir, max_side=512):
    """Shrink *image* to at most *max_side* px on its longest side, save it as
    ``input.jpg`` in *temp_dir*, and return ``(image, path_to_jpeg)``."""
    longest = max(image.size)
    if longest > max_side:  # Resize for faster processing
        ratio = max_side / longest
        # max(1, ...) keeps extremely thin images from collapsing to 0 px.
        new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    # JPEG cannot store alpha or palette images; saving an RGBA/"P" image
    # (common for PNG uploads) as JPEG raises, so normalise to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    input_path = os.path.join(temp_dir, "input.jpg")
    image.save(input_path, "JPEG", quality=95)
    return image, input_path


def _build_inference_command(matrix_dir, input_path, output_dir, frames, seed):
    """Assemble the argv for ``inference.py`` in *matrix_dir*.

    *matrix_dir* is expected to be absolute: every path in the command must
    be valid from inside *matrix_dir*, which is the subprocess's cwd.
    """
    cmd = [
        sys.executable,
        os.path.join(matrix_dir, "inference.py"),
        "--img_path", input_path,
        "--output_folder", output_dir,
        "--num_output_frames", str(frames),
        "--seed", str(seed),
    ]
    # Sorted so the choice is deterministic rather than directory-order.
    # NOTE(review): this picks the first *config*.yaml/.yml found anywhere
    # in the repo; confirm it is the config inference.py expects.
    config_files = sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(matrix_dir)
        for name in files
        if name.endswith(('.yaml', '.yml')) and 'config' in name.lower()
    )
    if config_files:
        cmd.extend(["--config_path", config_files[0]])
    if model_path:
        cmd.extend(["--pretrained_model_path", os.path.abspath(model_path)])
    return cmd


def _find_output_video(output_dir):
    """Return the first video file under *output_dir* (sorted by path), or None."""
    videos = sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(output_dir)
        for name in files
        if name.lower().endswith(('.mp4', '.avi', '.mov', '.gif'))
    )
    return videos[0] if videos else None


# Allocate GPU for 2 minutes max. On ZeroGPU hardware the decorator attaches a
# GPU for the duration of each call; outside ZeroGPU it has no effect.
# NOTE(review): the inference subprocess below may run for up to 300 s, longer
# than this 120 s allocation — confirm which limit is intended.
@spaces.GPU(duration=120)
def generate_video(input_image, num_frames, seed, progress=gr.Progress()):
    """Generate a video from one image by running the Matrix-Game-2.0 CLI.

    Ensures setup via ``download_and_setup_model()``, writes the (downscaled,
    RGB) image into a temporary directory, runs
    ``Matrix-Game/Matrix-Game-2/inference.py`` in a subprocess and copies the
    first video it writes to ``output_<seed><ext>`` in the current working
    directory. The temporary directory is always removed.

    Args:
        input_image: PIL image from the UI, or None if nothing was uploaded.
        num_frames: Requested frame count; truncated to int and capped at 100.
        seed: Random seed passed through to the inference script.
        progress: Gradio progress tracker.

    Returns:
        tuple: ``(video_path, status_text)``. ``video_path`` is None when
        generation failed, and ``status_text`` explains what happened.
    """
    if input_image is None:
        return None, "โ Please upload an input image first"

    # Setup model if not already done
    progress(0.1, desc="๐ง Setting up model...")
    if not download_and_setup_model():
        return None, "โ Failed to setup model"

    progress(0.2, desc="๐ท Processing input image...")
    temp_dir = None
    try:
        # UI widgets may hand over floats (e.g. 50.0); the CLI and the output
        # file name need plain integers.
        frames = min(int(num_frames), 100)  # Limit frames for HF Spaces
        seed = int(seed)

        temp_dir = tempfile.mkdtemp(prefix="matrix_gen_")
        output_dir = os.path.join(temp_dir, "outputs")
        os.makedirs(output_dir, exist_ok=True)
        input_image, input_path = _prepare_input_image(input_image, temp_dir)

        progress(0.4, desc="๐ Generating video...")
        # Absolute: the subprocess runs with cwd=matrix_dir, so a path relative
        # to *this* process's cwd would resolve to the wrong place.
        matrix_dir = os.path.abspath(os.path.join("Matrix-Game", "Matrix-Game-2"))
        cmd = _build_inference_command(matrix_dir, input_path, output_dir, frames, seed)

        progress(0.6, desc="๐ฌ Running inference...")
        process = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=300,  # 5 minute timeout
            cwd=matrix_dir,
        )

        progress(0.9, desc="๐น Finalizing video...")
        video_file = _find_output_video(output_dir)
        if video_file is None:
            # A Python traceback ends with the actual error, so show the tail.
            details = process.stderr[-500:] if process.stderr else 'No error details'
            error_log = f"""
โ **Generation Failed**
๐ Error output: {details}
๐ญ Try adjusting parameters or using a different input image
"""
            return None, error_log

        # Copy out of the temp dir (deleted in `finally`), keeping the real
        # container extension so e.g. a GIF is not mislabelled as .mp4.
        extension = os.path.splitext(video_file)[1].lower()
        final_output = f"output_{seed}{extension}"
        shutil.copy(video_file, final_output)
        log = f"""
โ **Generation Successful!**
๐ Input: {input_image.size}
๐ฌ Frames: {frames}
๐ฒ Seed: {seed}
๐ Output: {final_output}
"""
        progress(1.0, desc="โ Complete!")
        return final_output, log
    except subprocess.TimeoutExpired:
        return None, "โ Generation timed out (>5 minutes). Try fewer frames."
    except Exception as e:
        # UI boundary: surface any failure as a status message.
        return None, f"โ Error during generation: {str(e)}"
    finally:
        # Cleanup: the result (if any) has already been copied out.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
# Gradio Interface
def create_interface():
    """Build and return the Gradio Blocks UI for Matrix-Game-2.0.

    Layout: a header banner, then a row of two equal-width columns — image
    upload, frame/seed settings, the generate button and usage tips on the
    left; the generated video and a status log on the right — followed by
    two example inputs. Clicking the button calls
    ``generate_video(image, frames, seed)`` and routes its two return values
    into the video player and the status log.
    """
    page_css = """
.gradio-container {
max-width: 1200px !important;
margin: auto !important;
}
"""
    header_html = """
<div style="text-align: center; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 20px;">
<h1>๐ฎ Matrix-Game-2.0</h1>
<p style="font-size: 18px;">Interactive World Model for Real-Time Video Generation</p>
<p style="opacity: 0.8;">Upload an image and generate interactive video content!</p>
</div>
"""
    tips_markdown = """
### ๐ก Tips
- Use clear, well-lit images
- Landscapes and scenes work best
- Lower frame counts = faster generation
- Try different seeds for variety
"""
    # Each row is (image URL, frame count, seed), matching the input order.
    example_rows = [
        ["https://images.unsplash.com/photo-1506905925346-21bda4d32df4", 50, 42],
        ["https://images.unsplash.com/photo-1441974231531-c6227db76b6e", 75, 123],
    ]

    with gr.Blocks(title="Matrix-Game-2.0", theme=gr.themes.Soft(), css=page_css) as app:
        gr.HTML(header_html)

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1):
                gr.Markdown("### ๐ธ Input")
                image_in = gr.Image(label="Input Image", type="pil", height=300)

                gr.Markdown("### โ๏ธ Settings")
                with gr.Row():
                    frames_in = gr.Slider(
                        minimum=25, maximum=100, value=50, step=25, label="Number of Frames"
                    )
                    seed_in = gr.Number(value=42, label="Seed", precision=0)

                run_button = gr.Button("๐ Generate Video", variant="primary", size="lg")
                gr.Markdown(tips_markdown)

            # Right column: results.
            with gr.Column(scale=1):
                gr.Markdown("### ๐ฌ Generated Video")
                video_out = gr.Video(label="Result", height=400)
                log_out = gr.Textbox(label="Status Log", lines=8, max_lines=10)

        # Wire the button to the generator.
        run_button.click(
            fn=generate_video,
            inputs=[image_in, frames_in, seed_in],
            outputs=[video_out, log_out],
        )

        gr.Examples(
            examples=example_rows,
            inputs=[image_in, frames_in, seed_in],
            label="Example Images",
        )

    return app
# Launch the app
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860, without a share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo = create_interface()
    demo.launch(**launch_options)