# app.py
import spaces
import gradio as gr
import argparse
import sys
import os
import random
import subprocess
from PIL import Image
import numpy as np

# Run the setup script (presumably fetches/prepares the SkyReels-V1 repo);
# fail loudly if it errors instead of continuing with a broken sys.path.
subprocess.run(['sh', './sky.sh'], check=True)
sys.path.append("./SkyReels-V1")

from skyreelsinfer import TaskType
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from diffusers.utils import export_to_video
import torch
import logging
from collections import OrderedDict
# Disable TF32 and reduced-precision reductions so matmuls run in full
# precision, and keep cuDNN autotuning off for more reproducible numerics.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.set_float32_matmul_precision("highest")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger = logging.getLogger(__name__)
# --- Dummy classes (keep for standalone execution) ---
class OffloadConfig:
    def __init__(self, high_cpu_memory=False, parameters_level=False, compiler_transformer=False, compiler_cache=""):
        self.high_cpu_memory = high_cpu_memory
        self.parameters_level = parameters_level
        self.compiler_transformer = compiler_transformer
        self.compiler_cache = compiler_cache

class TaskType:  # Keep here for infer
    T2V = 0
    I2V = 1

class LlamaModel:
    @staticmethod
    def from_pretrained(*args, **kwargs):
        return LlamaModel()

    def to(self, device):
        return self

class HunyuanVideoTransformer3DModel:
    @staticmethod
    def from_pretrained(*args, **kwargs):
        return HunyuanVideoTransformer3DModel()

    def to(self, device):
        return self
class SkyreelsVideoPipeline:
    def __init__(self):
        super().__init__()
        self._modules = OrderedDict()
        self.vae = self.VAE()
        self._modules["vae"] = self.vae

    @staticmethod
    def from_pretrained(*args, **kwargs):
        return SkyreelsVideoPipeline()

    def to(self, device):
        return self

    def named_children(self):
        return self._modules.items()

    def __call__(self, *args, **kwargs):
        num_frames = kwargs.get("num_frames", 16)  # Default to 16 frames
        height = kwargs.get("height", 512)
        width = kwargs.get("width", 512)
        if "image" in kwargs:  # I2V
            image = kwargs["image"]
            # Convert PIL image to a float tensor normalized to [0, 1].
            image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
            image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
            # Add a time axis, repeat the frame along it, and add a little noise.
            frames = image_tensor.unsqueeze(2).repeat(1, 1, num_frames, 1, 1)  # (1, C, T, H, W)
            frames = frames + torch.randn_like(frames) * 0.05
        else:  # T2V
            frames = torch.randn(1, 3, num_frames, height, width)  # (1, C, T, H, W)
        # Mimic a diffusers-style output object with a .frames tensor attribute.
        return type('obj', (object,), {'frames': frames})()

    class VAE:
        def enable_tiling(self):
            pass
# No-op stubs for the quantization helpers used by the real pipeline.
def quantize_(*args, **kwargs):
    return

def float8_weight_only():
    return
# --- End dummy classes ---
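
# Hedged note: the dummy classes above deliberately shadow the imports from
# skyreelsinfer so the script's control flow can be exercised without the real
# SkyReels-V1 checkpoints. A minimal sketch of the dummy pipeline (this is the
# stub API defined above, not the real SkyReels interface):
#
#   pipe = SkyreelsVideoPipeline.from_pretrained("any-id")
#   out = pipe(prompt="test", num_frames=4, height=64, width=64)
#   assert out.frames.shape == (1, 3, 4, 64, 64)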
class SkyReelsVideoSingleGpuInfer:
    def _load_model(self, model_id: str, base_model_id: str = "hunyuanvideo-community/HunyuanVideo", quant_model: bool = True):
        logger.info(f"load model model_id:{model_id} quant_model:{quant_model}")
        text_encoder = LlamaModel.from_pretrained(
            base_model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
        ).to("cpu")
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device="cpu"
        ).to("cpu")
        if quant_model:
            quantize_(text_encoder, float8_weight_only())
            text_encoder.to("cpu")
            torch.cuda.empty_cache()
            quantize_(transformer, float8_weight_only())
            transformer.to("cpu")
            torch.cuda.empty_cache()
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16
        ).to("cpu")
        pipe.vae.enable_tiling()
        torch.cuda.empty_cache()
        return pipe
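
    # Hedged note: in the real SkyReels-V1 code, quantize_ and float8_weight_only
    # are expected to come from torchao (here they are the no-op stubs above).
    # The real usage would be along the lines of:
    #
    #   from torchao.quantization import quantize_, float8_weight_only
    #   quantize_(transformer, float8_weight_only())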
    def __init__(
        self,
        task_type: TaskType,
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config: OffloadConfig = OffloadConfig(),
        enable_cfg_parallel: bool = True,
    ):
        self.task_type = task_type
        self.model_id = model_id
        self.quant_model = quant_model
        self.is_offload = is_offload
        self.offload_config = offload_config
        self.enable_cfg_parallel = enable_cfg_parallel
        self.pipe = None
        self.is_initialized = False
        self.gpu_device = None
    def initialize(self):
        """Initializes the model and moves it to the GPU."""
        if self.is_initialized:
            return
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available. Cannot initialize model.")
        self.gpu_device = "cuda:0"
        self.pipe = self._load_model(model_id=self.model_id, quant_model=self.quant_model)
        if not self.is_offload:
            self.pipe.to(self.gpu_device)
        if self.offload_config.compiler_transformer:
            torch._dynamo.config.suppress_errors = True
            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{self.offload_config.compiler_cache}"
            self.pipe.transformer = torch.compile(
                self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=True
            )
        # Mark initialized before the warm-up pass, since warm_up() asserts it.
        self.is_initialized = True
        if self.offload_config.compiler_transformer:
            self.warm_up()
    def warm_up(self):
        if not self.is_initialized:
            raise RuntimeError("Model must be initialized before warm-up.")
        init_kwargs = {
            "prompt": "A woman is dancing in a room",
            "height": 544,
            "width": 960,
            "guidance_scale": 6,
            "num_inference_steps": 1,
            "negative_prompt": "bad quality",
            "num_frames": 16,
            "generator": torch.Generator(self.gpu_device).manual_seed(42),
            "embedded_guidance_scale": 1.0,
        }
        if self.task_type == TaskType.I2V:
            # PIL sizes are (width, height); match the 960x544 warm-up resolution.
            init_kwargs["image"] = Image.new("RGB", (960, 544), color="black")
        self.pipe(**init_kwargs)
        logger.info("Warm-up complete.")
    def infer(self, **kwargs):
        """Handles inference requests."""
        if not self.is_initialized:
            self.initialize()
        if "seed" in kwargs:
            kwargs["generator"] = torch.Generator(self.gpu_device).manual_seed(kwargs.pop("seed"))
        assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V
        return self.pipe(**kwargs).frames  # Return the frames tensor directly
_predictor = None
def generate_video(prompt, seed, image=None):
    global _predictor
    if seed == -1:
        random.seed()
        seed = int(random.randrange(4294967294))
    if image is None:
        task_type = TaskType.T2V
        model_id = "Skywork/SkyReels-V1-Hunyuan-T2V"
        kwargs = {
            "prompt": prompt,
            "height": 512,
            "width": 512,
            "num_frames": 16,
            "num_inference_steps": 30,
            "seed": seed,
            "guidance_scale": 7.5,
            "negative_prompt": "bad quality, worst quality",
        }
    else:
        task_type = TaskType.I2V
        model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
        kwargs = {
            "prompt": prompt,
            "image": Image.open(image),
            "height": 512,
            "width": 512,
            "num_frames": 97,
            "num_inference_steps": 30,
            "seed": seed,
            "guidance_scale": 6.0,
            "embedded_guidance_scale": 1.0,
            "negative_prompt": "Aerial view, low quality, bad hands",
            "cfg_for": False,
        }
    # Recreate the predictor when the task type changes, since T2V and I2V use
    # different checkpoints.
    if _predictor is None or _predictor.task_type != task_type:
        _predictor = SkyReelsVideoSingleGpuInfer(
            task_type=task_type,
            model_id=model_id,
            quant_model=True,
            is_offload=True,
            offload_config=OffloadConfig(
                high_cpu_memory=True,
                parameters_level=True,
                compiler_transformer=False,
            ),
        )
        _predictor.initialize()
        logger.info("Predictor initialized")
    with torch.no_grad():
        output = _predictor.infer(**kwargs)
    # Convert the (1, C, T, H, W) float tensor into a list of (H, W, C) frames
    # in [0, 1], as expected by diffusers' export_to_video.
    frames = output[0].permute(1, 2, 3, 0).clamp(0, 1).cpu().float().numpy()
    save_dir = "./result"
    os.makedirs(save_dir, exist_ok=True)
    video_out_file = f"{save_dir}/{seed}.mp4"
    print(f"generate video, local path: {video_out_file}")
    export_to_video(list(frames), video_out_file, fps=24)
    return video_out_file, kwargs
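
# Hedged smoke test: calling generate_video directly, bypassing the UI. With
# the dummy pipeline this writes a noise clip to ./result/<seed>.mp4 (a CUDA
# device is still required by initialize()):
#
#   video_path, params = generate_video("a red fox in the snow", seed=123)
#   print(video_path)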
def create_gradio_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image(label="Upload Image", type="filepath")
                prompt = gr.Textbox(label="Input Prompt")
                seed = gr.Number(label="Random Seed", value=-1)
            with gr.Column():
                submit_button = gr.Button("Generate Video")
                output_video = gr.Video(label="Generated Video")
                output_params = gr.Textbox(label="Output Parameters")
        submit_button.click(
            fn=generate_video,
            inputs=[prompt, seed, image],
            outputs=[output_video, output_params],
        )
    return demo
if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.queue().launch()
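
# Launch notes (assumptions, not from the original): demo.queue() enables
# request queuing; standard Gradio options such as launch(share=True) or
# launch(server_name="0.0.0.0") can expose the app beyond localhost.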