import spaces
import gradio as gr
import argparse
import sys
import time
import os
import random
import numpy as np
from skyreelsinfer.offload import OffloadConfig
from skyreelsinfer import TaskType
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
from diffusers.utils import export_to_video
from diffusers.utils import load_image
from PIL import Image
import torch

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
torch.backends.cuda.preferred_blas_library("cublas")
torch.backends.cuda.preferred_linalg_library("cusolver")
torch.set_float32_matmul_precision("highest")

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["SAFETENSORS_FAST_GPU"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "False"


def init_predictor():
    global pipe
    pipe = SkyReelsVideoSingleGpuInfer(
        task_type=TaskType.I2V,
        model_id="Skywork/SkyReels-V1-Hunyuan-I2V",
        quant_model=False,
        is_offload=False,
        offload_config=OffloadConfig(
            high_cpu_memory=True,
            parameters_level=True,
            compiler_transformer=False,
        ),
    )


@spaces.GPU(duration=60)
def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True)):
    # NOTE: `pipe` is used here as if it exposes the underlying HunyuanVideo I2V pipeline
    # API (encode_prompt, scheduler, transformer, vae, video_processor, prepare_latents, ...).
    device = "cuda"
    transformer_dtype = pipe.transformer.dtype
    segment = int(segment)
    frames = int(frames)
    num_inference_steps = int(num_inference_steps)
    if segment == 1:
        # Draw a fresh seed only for the first segment; later segments must reuse the seed
        # they receive so they can locate the state file written by the previous call.
        random.seed(time.time())
        seed = int(random.randrange(4294967294))
        generator = torch.Generator(device=device).manual_seed(seed)
        batch_size = 1
        # Assumption: the single "Size" slider drives both output dimensions.
        height = width = int(size)
        prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            device=device,
        )
        # The denoising loop duplicates the latents for classifier-free guidance, so the text
        # conditioning is batched as [unconditional, conditional], assuming an empty negative
        # prompt for the unconditional branch.
        negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = pipe.encode_prompt(
            prompt="",
            prompt_2="",
            device=device,
        )
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
        pipe.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = pipe.scheduler.timesteps
        all_timesteps_cpu = timesteps.cpu()
        # Split the full schedule into 8 chunks; this call runs only the first chunk.
        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
        segment_timesteps = torch.from_numpy(timesteps_split_np[0]).to(device)
        # The transformer's input channels cover both the video latents and the concatenated
        # image-condition latents, so each half gets in_channels / 2.
        num_channels_latents = pipe.transformer.config.in_channels
        num_channels_latents = int(num_channels_latents / 2)
        # gr.Image(type="filepath") hands us a path, so load it before preprocessing.
        image = load_image(image)
        image = pipe.video_processor.preprocess(image, height=height, width=width).to(
            device, dtype=prompt_embeds.dtype
        )
        num_latent_frames = (frames - 1) // pipe.vae_scale_factor_temporal + 1
        latents = pipe.prepare_latents(
            batch_size=batch_size,
            num_channels_latents=num_channels_latents,
            height=height,
            width=width,
            num_frames=frames,
            dtype=torch.float32,
            device=device,
            generator=generator,
            latents=None,
        )
        image_latents = pipe.image_latents(
            image, batch_size, height, width, device, torch.float32, num_channels_latents, num_latent_frames
        )
        image_latents = image_latents.to(pipe.transformer.dtype)
    else:
        seed = int(seed)
        # Resume from the state file saved by the previous segment.
        state_file = f"SkyReel_{segment - 1}_{seed}.pt"
        state = torch.load(state_file, weights_only=False)
        generator = torch.Generator(device=device).manual_seed(seed)
        latents = state["intermediate_latents"].to(device, dtype=torch.bfloat16)
        guidance_scale = state["guidance_scale"]
        all_timesteps_cpu = state["all_timesteps"]
        height = state["height"]
        width = state["width"]
        pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
        segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to(device)
        prompt_embeds = state["prompt_embeds"].to(device, dtype=torch.bfloat16)
        pooled_prompt_embeds = state["pooled_prompt_embeds"].to(device, dtype=torch.bfloat16)
        prompt_attention_mask = state["prompt_attention_mask"].to(device, dtype=torch.bfloat16)
        image_latents = state["image_latents"].to(device, dtype=torch.bfloat16)

    # Embedded guidance expected by the HunyuanVideo transformer; the batch is doubled because
    # classifier-free guidance duplicates the latents in the loop below.
    guidance = torch.tensor(
        [guidance_scale] * (2 * latents.shape[0]), dtype=transformer_dtype, device=device
    ) * 1000.0

    # Denoise only this segment's slice of the timestep schedule.
    for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
        latents = latents.to(transformer_dtype)
        latent_model_input = torch.cat([latents] * 2)
        latent_image_input = torch.cat([image_latents] * 2)
        latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=1)
        timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
        with torch.no_grad():
            noise_pred = pipe.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                pooled_projections=pooled_prompt_embeds,
                guidance=guidance,
                # attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    intermediate_latents_cpu = latents.detach().cpu()

    if segment == 8:
        # Final segment: decode the fully denoised latents and write the video to disk.
        latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
        video = pipe.vae.decode(latents, return_dict=False)[0]
        video = pipe.video_processor.postprocess_video(video, output_type="np")
        # return HunyuanVideoPipelineOutput(frames=video)
        save_dir = "./"
        video_out_file = f"{save_dir}/{seed}.mp4"
        print(f"generate video, local path: {video_out_file}")
        export_to_video(video[0], video_out_file, fps=24)
        return video_out_file, seed
    else:
        # Intermediate segment: stash everything the next call needs to resume.
        original_prompt_embeds_cpu = prompt_embeds.cpu()
        original_image_latents_cpu = image_latents.cpu()
        original_pooled_prompt_embeds_cpu = pooled_prompt_embeds.cpu()
        original_prompt_attention_mask_cpu = prompt_attention_mask.cpu()
        timesteps = pipe.scheduler.timesteps
        all_timesteps_cpu = timesteps.cpu()  # Move to CPU
        state = {
            "intermediate_latents": intermediate_latents_cpu,
            "all_timesteps": all_timesteps_cpu,  # Save full list generated by scheduler
            "prompt_embeds": original_prompt_embeds_cpu,  # Save ORIGINAL embeds
            "image_latents": original_image_latents_cpu,
            "pooled_prompt_embeds": original_pooled_prompt_embeds_cpu,
            "prompt_attention_mask": original_prompt_attention_mask_cpu,
            "guidance_scale": guidance_scale,
            "timesteps_split": timesteps_split_np,
            "seed": seed,
            "prompt": prompt,  # Save originals for reference/verification
            "height": height,  # Save dimensions used
            "width": width,
        }
        state_file = f"SkyReel_{segment}_{seed}.pt"
        torch.save(state, state_file)
        return None, seed


def update_ranges(total_steps):
    """Calculates and updates the starting steps for the 8 segment sliders."""
    step_size = total_steps // 8  # Calculate the size of each segment
    ranges = []
    for i in range(8):
        lower_bound = i * step_size
        ranges.append(lower_bound)  # One slider value per segment
    return ranges


with gr.Blocks() as demo:
    with gr.Row():
        image = gr.Image(label="Upload Image", type="filepath")
        prompt = gr.Textbox(label="Input Prompt")
        run_button_1 = gr.Button("Run Segment 1", scale=0)
        run_button_2 = gr.Button("Run Segment 2", scale=0)
        run_button_3 = gr.Button("Run Segment 3", scale=0)
        run_button_4 = gr.Button("Run Segment 4", scale=0)
        run_button_5 = gr.Button("Run Segment 5", scale=0)
        run_button_6 = gr.Button("Run Segment 6", scale=0)
        run_button_7 = gr.Button("Run Segment 7", scale=0)
        run_button_8 = gr.Button("Run Segment 8", scale=0)
    result = gr.Gallery(label="Result", columns=1, show_label=False)
    seed = gr.Number(value=1, label="Seed")
label="Seed") size = gr.Slider( label="Size", minimum=256, maximum=1024, step=16, value=368, ) frames = gr.Slider( label="Number of Frames", minimum=16, maximum=256, step=8, value=64, ) steps = gr.Slider( label="Number of Steps", minimum=1, maximum=96, step=1, value=25, ) guidance_scale = gr.Slider( label="Guidance Scale", minimum=1.0, maximum=16.0, step=.1, value=6.0, ) submit_button = gr.Button("Generate Video") output_video = gr.Video(label="Generated Video") range_sliders = [] for i in range(8): slider = gr.Slider( minimum=1, maximum=250, value=[i * (steps.value // 8)], step=1, label=f"Range {i + 1}", ) range_sliders.append(slider) steps.change( update_ranges, inputs=steps, outputs=range_sliders, ) gr.on( triggers=[ run_button_1.click, ], fn=generate, inputs=[ gr.Number(value=4), image, prompt, size, guidance_scale, num_inference_steps, frames, seed, ], outputs=[result, seed], ) gr.on( triggers=[ run_button_2.click, ], fn=generate, inputs=[ gr.Number(value=4), image, prompt, size, guidance_scale, num_inference_steps, frames, seed, ], outputs=[result, seed], ) gr.on( triggers=[ run_button_3.click, ], fn=generate, inputs=[ gr.Number(value=4), image, prompt, size, guidance_scale, num_inference_steps, frames, seed, ], outputs=[result, seed], ) gr.on( triggers=[ run_button_4.click, ], fn=generate, inputs=[ gr.Number(value=4), image, prompt, size, guidance_scale, num_inference_steps, frames, seed, ], outputs=[result, seed], ) gr.on( triggers=[ run_button_5.click, ], fn=generate, inputs=[ gr.Number(value=4), image, prompt, size, guidance_scale, num_inference_steps, frames, seed, ], outputs=[result, seed], ) gr.on( triggers=[ run_button_6.click, ], fn=generate, inputs=[ gr.Number(value=4), image, prompt, size, guidance_scale, num_inference_steps, frames, seed, ], outputs=[result, seed], ) gr.on( triggers=[ run_button_7.click, ], fn=generate, inputs=[ gr.Number(value=4), image, prompt, size, guidance_scale, num_inference_steps, frames, seed, ], outputs=[result, seed], ) gr.on( triggers=[ run_button_8.click, ], fn=generate, inputs=[ gr.Number(value=4), image, prompt, size, guidance_scale, num_inference_steps, frames, seed, ], outputs=[result, seed], ) if __name__ == "__main__": init_predictor() demo.launch()