import os

# Get weights path from environment variable or use default
WEIGHTS_PATH = os.environ.get("WEIGHTS_PATH", "/data/weights")

# IMPORTANT: Set environment variables BEFORE importing any modules that use them
os.environ["MODEL_BASE"] = os.path.join(WEIGHTS_PATH, "stdmodels")
os.environ["DISABLE_SP"] = "1"

# Configure CPU_OFFLOAD in system environment variables:
# Set CPU_OFFLOAD=1 to enable CPU offloading (for low VRAM, but slower)
# Set CPU_OFFLOAD=0 to disable CPU offloading (requires more VRAM, but faster)
# os.environ["CPU_OFFLOAD"] = "1"

import torch
import gradio as gr
import numpy as np
import random
from pathlib import Path
from PIL import Image
import torchvision.transforms as transforms
from loguru import logger
from huggingface_hub import hf_hub_download
import tempfile
import spaces

from hymm_sp.sample_inference import HunyuanVideoSampler
from hymm_sp.data_kits.data_tools import save_videos_grid
from hymm_sp.config import parse_args
import argparse

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def detect_gpu_supports_fp8():
    """Detect if the current GPU supports FP8 operations."""
    if not torch.cuda.is_available():
        return False
    try:
        # Get compute capability
        compute_capability = torch.cuda.get_device_capability()
        major, minor = compute_capability
        # Get GPU name for logging
        gpu_name = torch.cuda.get_device_name()
        # FP8 with fp8e4m3fn (fp8e4nv) requires compute capability >= 9.0 (H100, H200)
        # A100 has compute capability 8.0 and doesn't support this FP8 variant
        supports_fp8 = major >= 9
        logger.info(f"GPU detected: {gpu_name} (compute capability {major}.{minor})")
        logger.info(f"FP8 support: {'Enabled' if supports_fp8 else 'Disabled (requires compute capability >= 9.0)'}")
        return supports_fp8
    except Exception as e:
        logger.warning(f"Could not detect GPU capabilities: {e}. Disabling FP8.")
        return False
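
# For reference: A100 has compute capability 8.0, L40S / RTX 4090 have 8.9, and
# H100 / H200 have 9.0, so the check above enables FP8 only on Hopper-class GPUs.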


class CropResize:
    def __init__(self, size=(704, 1216)):
        self.target_h, self.target_w = size

    def __call__(self, img):
        w, h = img.size
        scale = max(
            self.target_w / w,
            self.target_h / h
        )
        new_size = (int(h * scale), int(w * scale))
        resize_transform = transforms.Resize(
            new_size,
            interpolation=transforms.InterpolationMode.BILINEAR
        )
        resized_img = resize_transform(img)
        crop_transform = transforms.CenterCrop((self.target_h, self.target_w))
        return crop_transform(resized_img)
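
# Worked example (illustrative): with the default (704, 1216) target, a 1280x720 input
# gives scale = max(1216/1280, 704/720) ≈ 0.978, an intermediate resize to roughly
# 704x1251 (h x w), and a final center crop to exactly 704x1216.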


def create_args():
    args = argparse.Namespace()
    args.ckpt = os.path.join(WEIGHTS_PATH, "gamecraft_models/mp_rank_00_model_states_distill.pt")
    args.video_size = [704, 1216]
    args.cfg_scale = 1.0
    args.image_start = True
    args.seed = None
    args.infer_steps = 8
    args.use_fp8 = detect_gpu_supports_fp8()  # Auto-detect FP8 support based on GPU
    args.flow_shift_eval_video = 5.0
    args.sample_n_frames = 33
    args.num_images = 1
    args.use_linear_quadratic_schedule = False
    args.linear_schedule_end = 0.25
    args.use_deepcache = False
    args.cpu_offload = False  # Always False for ZeroGPU compatibility
    args.use_sage = True
    args.save_path = './results/'
    args.save_path_suffix = ''
    args.add_pos_prompt = "Realistic, High-quality."
    args.add_neg_prompt = "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border."
    args.model = "HYVideo-T/2"
    args.precision = "bf16"
    args.vae = "884-16c-hy0801"
    args.vae_precision = "fp16"
    args.text_encoder = "llava-llama-3-8b"
    args.text_encoder_precision = "fp16"
    args.text_encoder_precision_2 = "fp16"
    args.tokenizer = "llava-llama-3-8b"
    args.text_encoder_2 = "clipL"
    args.tokenizer_2 = "clipL"
    args.latent_channels = 16
    args.text_len = 256
    args.text_len_2 = 77
    args.use_attention_mask = True
    args.hidden_state_skip_layer = 2
    args.apply_final_norm = False
    args.prompt_template_video = "li-dit-encode-video"
    args.reproduce = False
    args.load_key = "module"
    # text encoder related attributes
    args.text_projection = "single_refiner"
    args.text_states_dim = 4096
    args.text_states_dim_2 = 768
    # default is True based on config.py
    args.flow_reverse = True
    # default is "euler" based on config.py
    args.flow_solver = "euler"
    # default is 256 based on config.py
    args.rope_theta = 256
    # default for HYVideo-T/2 model
    args.patch_size = [1, 2, 2]
    # default is True based on config.py
    args.vae_tiling = True
    # default is 0 based on config.py
    args.ip_cfg_scale = 0.0
    # val_disable_autocast is needed by the pipeline
    args.val_disable_autocast = False
    return args
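
# Illustrative only (not wired into the app): the distilled checkpoint is tuned for
# cfg_scale=1.0 and 8 inference steps. If these needed to be tunable without editing
# code, they could be read from hypothetical environment variables, e.g.:
#   args.infer_steps = int(os.environ.get("INFER_STEPS", "8"))
#   args.cfg_scale = float(os.environ.get("CFG_SCALE", "1.0"))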

# Define all required model files
required_files = [
    "gamecraft_models/mp_rank_00_model_states_distill.pt",
    "stdmodels/vae_3d/hyvae/config.json",
    "stdmodels/vae_3d/hyvae/pytorch_model.pt",
]

# Check and download missing files
for file_path in required_files:
    full_path = os.path.join(WEIGHTS_PATH, file_path)
    if not os.path.exists(full_path):
        logger.info(f"Downloading {file_path} from Hugging Face...")
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        try:
            hf_hub_download(
                repo_id="tencent/Hunyuan-GameCraft-1.0",
                filename=file_path,
                local_dir=WEIGHTS_PATH,
                local_dir_use_symlinks=False
            )
            logger.info(f"Successfully downloaded {file_path}")
        except Exception as e:
            logger.error(f"Failed to download {file_path}: {e}")
            raise

# Also check for text encoder files (download if needed)
text_encoder_files = [
    "stdmodels/llava-llama-3-8b-v1_1-transformers/model-00001-of-00004.safetensors",
    "stdmodels/llava-llama-3-8b-v1_1-transformers/model-00002-of-00004.safetensors",
    "stdmodels/llava-llama-3-8b-v1_1-transformers/model-00003-of-00004.safetensors",
    "stdmodels/llava-llama-3-8b-v1_1-transformers/model-00004-of-00004.safetensors",
    "stdmodels/llava-llama-3-8b-v1_1-transformers/model.safetensors.index.json",
    "stdmodels/llava-llama-3-8b-v1_1-transformers/config.json",
    "stdmodels/llava-llama-3-8b-v1_1-transformers/tokenizer.json",
    "stdmodels/llava-llama-3-8b-v1_1-transformers/tokenizer_config.json",
    "stdmodels/llava-llama-3-8b-v1_1-transformers/special_tokens_map.json",
    "stdmodels/openai_clip-vit-large-patch14/config.json",
    "stdmodels/openai_clip-vit-large-patch14/pytorch_model.bin",
    "stdmodels/openai_clip-vit-large-patch14/tokenizer.json",
    "stdmodels/openai_clip-vit-large-patch14/tokenizer_config.json",
    "stdmodels/openai_clip-vit-large-patch14/special_tokens_map.json",
    "stdmodels/openai_clip-vit-large-patch14/vocab.json",
    "stdmodels/openai_clip-vit-large-patch14/merges.txt",
]

for file_path in text_encoder_files:
    full_path = os.path.join(WEIGHTS_PATH, file_path)
    if not os.path.exists(full_path):
        logger.info(f"Downloading {file_path} from Hugging Face...")
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        try:
            hf_hub_download(
                repo_id="tencent/Hunyuan-GameCraft-1.0",
                filename=file_path,
                local_dir=WEIGHTS_PATH,
                local_dir_use_symlinks=False
            )
            logger.info(f"Successfully downloaded {file_path}")
        except Exception as e:
            logger.error(f"Failed to download {file_path}: {e}")
            # Continue anyway as some files might be optional

logger.info("All required model files are ready")
logger.info("Initializing Hunyuan-GameCraft model...") | |
args = create_args() | |
logger.info(f"Created args, val_disable_autocast: {hasattr(args, 'val_disable_autocast')} = {getattr(args, 'val_disable_autocast', 'NOT SET')}") | |
# For ZeroGPU, always load model to CPU initially (it will be moved to GPU during inference) | |
model_device = torch.device("cpu") | |
logger.info(f"Loading model to device: {model_device}") | |
hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained( | |
args.ckpt, | |
args=args, | |
device=model_device | |
) | |
logger.info(f"After from_pretrained, sampler.args has val_disable_autocast: {hasattr(hunyuan_video_sampler.args, 'val_disable_autocast')} = {getattr(hunyuan_video_sampler.args, 'val_disable_autocast', 'NOT SET')}") | |
args = hunyuan_video_sampler.args | |
logger.info(f"After reassigning args, val_disable_autocast: {hasattr(args, 'val_disable_autocast')} = {getattr(args, 'val_disable_autocast', 'NOT SET')}") | |
# Don't apply CPU offloading for ZeroGPU - the model stays on CPU until needed | |
logger.info("Model loaded successfully on CPU, will be moved to GPU during inference") | |


# Request a ZeroGPU allocation for the duration of each call (this is what the `spaces`
# import above is for); outside of ZeroGPU Spaces the decorator has no effect.
@spaces.GPU
def generate_video(
    input_image,
    prompt,
    action_sequence,
    action_speeds,
    negative_prompt,
    seed,
    cfg_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True)
):
    try:
        progress(0, desc="Initializing...")
        if input_image is None:
            return None, "Please upload an image first!"

        # Move model components to GPU for ZeroGPU inference
        logger.info("Moving model components to GPU...")
        hunyuan_video_sampler.pipeline.transformer.to('cuda')
        hunyuan_video_sampler.vae.to('cuda')
        if hunyuan_video_sampler.text_encoder:
            hunyuan_video_sampler.text_encoder.model.to('cuda')
        if hunyuan_video_sampler.text_encoder_2:
            hunyuan_video_sampler.text_encoder_2.model.to('cuda')
        logger.info("Model components moved to GPU")

        action_list = action_sequence.lower().replace(" ", "").split(",") if action_sequence else ["w"]
        speed_list = [float(s.strip()) for s in action_speeds.split(",")] if action_speeds else [0.2]

        if len(speed_list) != len(action_list):
            if len(speed_list) == 1:
                speed_list = speed_list * len(action_list)
            else:
                return None, f"Error: Number of speeds ({len(speed_list)}) must match number of actions ({len(action_list)})"

        for action in action_list:
            if action not in ['w', 'a', 's', 'd']:
                return None, f"Error: Invalid action '{action}'. Use only w, a, s, d"

        for speed in speed_list:
            if not 0.0 <= speed <= 3.0:
                return None, f"Error: Speed {speed} out of range. Use values between 0.0 and 3.0"

        progress(0.1, desc="Processing image...")
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            input_image.save(tmp_file.name)
            image_path = tmp_file.name

        closest_size = (704, 1216)
        ref_image_transform = transforms.Compose([
            CropResize(closest_size),
            transforms.CenterCrop(closest_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        raw_ref_image = Image.open(image_path).convert('RGB')
        ref_image_pixel_values = ref_image_transform(raw_ref_image)
        ref_image_pixel_values = ref_image_pixel_values.unsqueeze(0).unsqueeze(2).to(device)

        progress(0.2, desc="Encoding image...")
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
            hunyuan_video_sampler.pipeline.vae.enable_tiling()
            raw_last_latents = hunyuan_video_sampler.vae.encode(
                ref_image_pixel_values
            ).latent_dist.sample().to(dtype=torch.float16)
            raw_last_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
            raw_ref_latents = raw_last_latents.clone()
            hunyuan_video_sampler.pipeline.vae.disable_tiling()

        ref_images = [raw_ref_image]
        last_latents = raw_last_latents
        ref_latents = raw_ref_latents
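        # Rough shape check (assuming the "884-16c" VAE compresses 8x spatially with 16
        # latent channels): a single 704x1216 reference frame encodes to latents of
        # about (1, 16, 1, 88, 152).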

        progress(0.3, desc="Starting video generation...")
        if seed is None or seed == -1:
            seed = random.randint(0, 1_000_000)

        all_samples = []
        for idx, (action_id, action_speed) in enumerate(zip(action_list, speed_list)):
            is_image = (idx == 0)
            progress(0.3 + (0.6 * idx / len(action_list)),
                     desc=f"Generating segment {idx+1}/{len(action_list)} (action: {action_id})")
            logger.info(f"Before predict call {idx}, args has val_disable_autocast: {hasattr(args, 'val_disable_autocast')} = {getattr(args, 'val_disable_autocast', 'NOT SET')}")
            logger.info(f"hunyuan_video_sampler.args has val_disable_autocast: {hasattr(hunyuan_video_sampler.args, 'val_disable_autocast')} = {getattr(hunyuan_video_sampler.args, 'val_disable_autocast', 'NOT SET')}")
            outputs = hunyuan_video_sampler.predict(
                prompt=prompt,
                action_id=action_id,
                action_speed=action_speed,
                is_image=is_image,
                size=(704, 1216),
                seed=seed,
                last_latents=last_latents,
                ref_latents=ref_latents,
                video_length=args.sample_n_frames,
                guidance_scale=cfg_scale,
                num_images_per_prompt=1,
                negative_prompt=negative_prompt,
                infer_steps=num_inference_steps,
                flow_shift=args.flow_shift_eval_video,
                use_linear_quadratic_schedule=args.use_linear_quadratic_schedule,
                linear_schedule_end=args.linear_schedule_end,
                use_deepcache=args.use_deepcache,
                cpu_offload=False,  # Always False for ZeroGPU
                ref_images=ref_images,
                output_dir=None,
                return_latents=True,
                use_sage=args.use_sage,
            )
            ref_latents = outputs["ref_latents"]
            last_latents = outputs["last_latents"]
            sub_samples = outputs['samples'][0]
            all_samples.append(sub_samples)

        progress(0.9, desc="Finalizing video...")
        # Concatenate per-action segments along the frame dimension
        # (e.g. 4 actions x 33 frames = 132 frames, about 5.3 s at 25 FPS)
        if len(all_samples) > 1:
            out_cat = torch.cat(all_samples, dim=2)
        else:
            out_cat = all_samples[0]

        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
            output_path = tmp_video.name
            save_videos_grid(out_cat, output_path, n_rows=1, fps=25)

        if os.path.exists(image_path):
            os.remove(image_path)

        progress(1.0, desc="Complete!")
        return output_path, "Video generated successfully!"

    except Exception as e:
        logger.error(f"Error generating video: {e}")
        return None, f"Error: {str(e)}"


with gr.Blocks(title="Hunyuan-GameCraft") as demo:
    gr.Markdown("""
    # 🎮 Hunyuan-GameCraft Video Generation
    Generate interactive game-style videos from a single image using keyboard actions (W/A/S/D).
    Using the **distilled model** for faster generation (8 inference steps).
    """)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=400
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the scene...",
                value="A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky.",
                lines=3
            )
            with gr.Accordion("Action Controls", open=True):
                action_sequence = gr.Textbox(
                    label="Action Sequence (comma-separated)",
                    placeholder="w, a, s, d",
                    value="w, s, d, a",
                    info="Use w (forward), a (left), s (backward), d (right)"
                )
                action_speeds = gr.Textbox(
                    label="Action Speeds (comma-separated)",
                    placeholder="0.2, 0.2, 0.2, 0.2",
                    value="0.2, 0.2, 0.2, 0.2",
                    info="Speed for each action (0.0 to 3.0). Single value applies to all."
                )
with gr.Accordion("Advanced Settings", open=False): | |
negative_prompt = gr.Textbox( | |
label="Negative Prompt", | |
value="overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border.", | |
lines=2 | |
) | |
seed = gr.Number( | |
label="Seed", | |
value=-1, | |
precision=0, | |
info="Set to -1 for random seed" | |
) | |
cfg_scale = gr.Slider( | |
label="CFG Scale", | |
minimum=0.5, | |
maximum=3.0, | |
value=1.0, | |
step=0.1, | |
info="Classifier-free guidance scale (1.0 for distilled model)" | |
) | |
num_inference_steps = gr.Slider( | |
label="Inference Steps", | |
minimum=4, | |
maximum=20, | |
value=8, | |
step=1, | |
info="Number of denoising steps (8 for distilled model)" | |
) | |
generate_btn = gr.Button("Generate Video", variant="primary") | |
with gr.Column(scale=1): | |
output_video = gr.Video( | |
label="Generated Video", | |
height=400 | |
) | |
status_text = gr.Textbox( | |
label="Status", | |
interactive=False | |
) | |
gr.Markdown(""" | |
### Tips: | |
- Each action generates 33 frames (1.3 seconds at 25 FPS) | |
- The distilled model is optimized for speed with 8 inference steps | |
- Use FP8 optimization for better memory efficiency | |
- Minimum GPU memory: 24GB VRAM | |
""") | |

    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image,
            prompt,
            action_sequence,
            action_speeds,
            negative_prompt,
            seed,
            cfg_scale,
            num_inference_steps
        ],
        outputs=[output_video, status_text]
    )

    gr.Examples(
        examples=[
            [
                "asset/village.png",
                "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky.",
                "w, a, d, s",
                "0.2, 0.2, 0.2, 0.2"
            ]
        ],
        inputs=[input_image, prompt, action_sequence, action_speeds],
        label="Example"
    )


if __name__ == "__main__":
    demo.launch(share=True)
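
# Illustrative local launch (shell), assuming the weights are already available under
# WEIGHTS_PATH:
#   python app.py
# On Hugging Face Spaces the app is started automatically; there, share=True only
# prints a warning and no public share link is created.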