# Matrix-Game-2 / app.py
# Page residue from the hosting site's file viewer: last commit "Update app.py"
# (d9237e4, verified) by laloadrianmorales; file size 9.23 kB.
import os
import gradio as gr
import torch
import spaces
from PIL import Image
import tempfile
import subprocess
import sys
from huggingface_hub import snapshot_download, hf_hub_download
import shutil
# Configuration
MODEL_REPO = "Skywork/Matrix-Game-2.0"  # repo id passed to snapshot_download
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Startup banner so the logs show which device was selected
print("๐Ÿš€ Matrix-Game-2.0 Streamlined")
print(f"๐Ÿ“ฑ Device: {DEVICE}")
print(f"๐Ÿ”ฅ CUDA Available: {torch.cuda.is_available()}")

# Global variables for model loading
model_loaded = False  # flipped to True once download_and_setup_model succeeds
model_path = None  # local snapshot directory, filled in by download_and_setup_model
def download_and_setup_model():
    """Download the model snapshot and the inference code; effective once.

    Steps, in order:

    1. Download a snapshot of ``MODEL_REPO`` into ``./model_cache`` and store
       the returned path in the module global ``model_path``.
    2. Clone the Matrix-Game GitHub repository into ``./Matrix-Game`` unless
       that path already exists.
    3. Put ``<cwd>/Matrix-Game/Matrix-Game-2`` at the front of ``sys.path``.

    On success the module global ``model_loaded`` is set so that later calls
    return immediately.

    Returns:
        bool: ``True`` if setup succeeded (now or on a previous call),
        ``False`` otherwise. Errors are printed and never propagated.
    """
    global model_loaded, model_path

    if model_loaded:
        return True

    repo_dir = "Matrix-Game"
    try:
        print("๐Ÿ“ฅ Downloading Matrix-Game-2.0 model...")
        # Download the model to cache
        model_path = snapshot_download(
            repo_id=MODEL_REPO,
            cache_dir="./model_cache",
            allow_patterns=["*.safetensors", "*.bin", "*.json", "*.yaml", "*.yml", "*.py"],
        )
        print(f"โœ… Model downloaded to: {model_path}")

        # Clone the inference code repository
        if not os.path.exists(repo_dir):
            print("๐Ÿ“ฅ Cloning Matrix-Game repository...")
            try:
                result = subprocess.run(
                    ["git", "clone", "https://github.com/SkyworkAI/Matrix-Game.git"],
                    capture_output=True,
                    text=True,
                    timeout=180,
                )
            except subprocess.TimeoutExpired:
                # A killed clone leaves a partial checkout behind; remove it so
                # the next attempt does not mistake it for a finished clone.
                shutil.rmtree(repo_dir, ignore_errors=True)
                raise
            if result.returncode != 0:
                # Same reasoning: never leave a half-written directory around.
                shutil.rmtree(repo_dir, ignore_errors=True)
                print(f"โŒ Git clone failed: {result.stderr}")
                return False

        # Setup Python path to include Matrix-Game modules
        matrix_game_path = os.path.join(os.getcwd(), repo_dir, "Matrix-Game-2")
        if not os.path.isdir(matrix_game_path):
            # Do not cache "loaded" when the inference code is not actually there.
            raise FileNotFoundError(
                f"inference code directory not found: {matrix_game_path}"
            )
        if matrix_game_path not in sys.path:
            sys.path.insert(0, matrix_game_path)

        model_loaded = True
        return True
    except Exception as e:
        # Top-level boundary for the UI: report the failure and signal it
        # through the return value instead of crashing the request.
        print(f"โŒ Setup failed: {e}")
        return False
def _prepare_input_image(image, dest_dir, max_side=512):
    """Shrink *image* to at most *max_side* px on its longest side and save it as JPEG.

    Returns the (possibly resized/converted) image and the path it was saved to.
    """
    longest = max(image.size)
    if longest > max_side:  # Resize for faster processing
        ratio = max_side / longest
        new_size = (
            max(1, int(image.size[0] * ratio)),
            max(1, int(image.size[1] * ratio)),
        )
        image = image.resize(new_size, Image.Resampling.LANCZOS)
    if image.mode != "RGB":
        # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...); saving
        # such an image would raise, so normalise to RGB first.
        image = image.convert("RGB")
    path = os.path.join(dest_dir, "input.jpg")
    image.save(path, "JPEG", quality=95)
    return image, path


def _find_files(root_dir, suffixes, name_contains=""):
    """Return the sorted paths under *root_dir* whose lower-cased file name ends
    with one of *suffixes* and contains *name_contains*."""
    matches = []
    for root, _dirs, files in os.walk(root_dir):
        for name in files:
            lowered = name.lower()
            if lowered.endswith(suffixes) and name_contains in lowered:
                matches.append(os.path.join(root, name))
    # os.walk order is filesystem-dependent; sort so "first match" is stable.
    return sorted(matches)


@spaces.GPU(duration=120)  # Allocate GPU for 2 minutes max
def generate_video(input_image, num_frames, seed, progress=gr.Progress()):
    """Generate a video from one image by running the Matrix-Game-2 inference script.

    The image is downscaled and written to a temporary directory, then
    ``inference.py`` from the cloned repository is run as a subprocess. The
    first video file found in the output folder is copied into the current
    working directory as ``output_<seed><ext>``.

    Args:
        input_image: PIL image from the UI, or ``None`` if nothing was uploaded.
        num_frames: Requested frame count; capped at 100 before being passed on.
        seed: Seed forwarded to the inference script.
        progress: Gradio progress callback.

    Returns:
        tuple: ``(video_path, log_text)`` on success, ``(None, message)`` on
        any failure. Exceptions are converted into the message.
    """
    if input_image is None:
        return None, "โŒ Please upload an input image first"

    # Setup model if not already done
    progress(0.1, desc="๐Ÿ”ง Setting up model...")
    if not download_and_setup_model():
        return None, "โŒ Failed to setup model"

    progress(0.2, desc="๐Ÿ“ท Processing input image...")

    temp_dir = None
    try:
        # UI widgets may hand over floats; the CLI arguments must be plain ints.
        num_frames = int(num_frames)
        seed = int(seed)

        # Create temporary directories
        temp_dir = tempfile.mkdtemp(prefix="matrix_gen_")
        output_dir = os.path.join(temp_dir, "outputs")
        os.makedirs(output_dir, exist_ok=True)

        # Prepare input image
        input_image, input_path = _prepare_input_image(input_image, temp_dir)

        progress(0.4, desc="๐Ÿš€ Generating video...")

        # The subprocess runs with cwd=matrix_dir, so every path handed to it
        # must be absolute; a path relative to *our* cwd would be resolved
        # against matrix_dir by the child and not be found.
        matrix_dir = os.path.abspath(os.path.join("Matrix-Game", "Matrix-Game-2"))

        # Basic inference command (simplified)
        cmd = [
            sys.executable,
            os.path.join(matrix_dir, "inference.py"),
            "--img_path", input_path,
            "--output_folder", output_dir,
            "--num_output_frames", str(min(num_frames, 100)),  # Limit frames for HF Spaces
            "--seed", str(seed),
        ]

        # Add model and config paths if found
        config_files = _find_files(matrix_dir, (".yaml", ".yml"), name_contains="config")
        if config_files:
            cmd.extend(["--config_path", config_files[0]])
        if model_path:
            cmd.extend(["--pretrained_model_path", os.path.abspath(model_path)])

        progress(0.6, desc="๐ŸŽฌ Running inference...")

        # Execute with timeout
        # NOTE(review): this 300 s timeout is longer than the 120 s GPU
        # allocation requested by the decorator - confirm which limit is intended.
        process = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=300,  # 5 minute timeout
            cwd=matrix_dir,
        )

        progress(0.9, desc="๐Ÿ“น Finalizing video...")

        # Find output video
        video_files = _find_files(output_dir, (".mp4", ".avi", ".mov", ".gif"))
        if not video_files:
            # The tail of stderr is where a traceback puts the actual error.
            stderr_tail = process.stderr[-500:] if process.stderr else 'No error details'
            error_log = (
                "\n"
                "โŒ **Generation Failed**\n"
                f"๐Ÿ“ Error output: {stderr_tail}\n"
                "๐Ÿ’ญ Try adjusting parameters or using a different input image\n"
            )
            return None, error_log

        # Copy to a permanent location, keeping the real container extension
        # so a .gif/.avi/.mov result is not mislabelled as .mp4.
        extension = os.path.splitext(video_files[0])[1].lower()
        final_output = f"output_{seed}{extension}"
        shutil.copy(video_files[0], final_output)

        log = (
            "\n"
            "โœ… **Generation Successful!**\n"
            f"๐Ÿ“Š Input: {input_image.size}\n"
            f"๐ŸŽฌ Frames: {num_frames}\n"
            f"๐ŸŽฒ Seed: {seed}\n"
            f"๐Ÿ“ Output: {final_output}\n"
        )
        progress(1.0, desc="โœ… Complete!")
        return final_output, log
    except subprocess.TimeoutExpired:
        return None, "โŒ Generation timed out (>5 minutes). Try fewer frames."
    except Exception as e:
        # UI boundary: surface the problem in the status box instead of raising.
        return None, f"โŒ Error during generation: {str(e)}"
    finally:
        # Cleanup
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
# Gradio Interface
def create_interface():
    """Assemble the Gradio Blocks page and return it without launching.

    Layout: an HTML banner, then a two-column row. The left column holds the
    image input, the frame-count slider, the seed field, the generate button
    and a tips panel; the right column holds the result video and a status
    textbox. The button is wired to ``generate_video`` and two example rows
    are offered.
    """
    # Static page text, kept together so the layout code below stays short.
    page_css = """
.gradio-container {
max-width: 1200px !important;
margin: auto !important;
}
"""
    banner_html = """
<div style="text-align: center; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 15px; margin-bottom: 20px;">
<h1>๐ŸŽฎ Matrix-Game-2.0</h1>
<p style="font-size: 18px;">Interactive World Model for Real-Time Video Generation</p>
<p style="opacity: 0.8;">Upload an image and generate interactive video content!</p>
</div>
"""
    tips_markdown = """
### ๐Ÿ’ก Tips
- Use clear, well-lit images
- Landscapes and scenes work best
- Lower frame counts = faster generation
- Try different seeds for variety
"""
    # Each row is [image URL, number of frames, seed].
    example_rows = [
        ["https://images.unsplash.com/photo-1506905925346-21bda4d32df4", 50, 42],
        ["https://images.unsplash.com/photo-1441974231531-c6227db76b6e", 75, 123],
    ]

    with gr.Blocks(title="Matrix-Game-2.0", theme=gr.themes.Soft(), css=page_css) as app:
        gr.HTML(banner_html)

        with gr.Row():
            # Left column: inputs and controls
            with gr.Column(scale=1):
                gr.Markdown("### ๐Ÿ“ธ Input")
                image_input = gr.Image(label="Input Image", type="pil", height=300)

                gr.Markdown("### โš™๏ธ Settings")
                with gr.Row():
                    frames_slider = gr.Slider(
                        minimum=25,
                        maximum=100,
                        value=50,
                        step=25,
                        label="Number of Frames",
                    )
                    seed_field = gr.Number(value=42, label="Seed", precision=0)

                run_button = gr.Button("๐Ÿš€ Generate Video", variant="primary", size="lg")
                gr.Markdown(tips_markdown)

            # Right column: results
            with gr.Column(scale=1):
                gr.Markdown("### ๐ŸŽฌ Generated Video")
                video_output = gr.Video(label="Result", height=400)
                status_box = gr.Textbox(label="Status Log", lines=8, max_lines=10)

        # The same three components feed both the click handler and the examples.
        control_inputs = [image_input, frames_slider, seed_field]

        # Event handlers
        run_button.click(
            fn=generate_video,
            inputs=control_inputs,
            outputs=[video_output, status_box],
        )

        # Example inputs
        gr.Examples(
            examples=example_rows,
            inputs=[image_input, frames_slider, seed_field],
            label="Example Images",
        )

    return app
# Launch the app
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 and do not create a share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo = create_interface()
    demo.launch(**launch_options)