# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import time
import traceback
import types
from contextlib import contextmanager
from functools import partial

import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm

from .modules.vace_model import VaceWanModel
from .text2video import (
    FlowDPMSolverMultistepScheduler,
    FlowUniPCMultistepScheduler,
    T5EncoderModel,
    WanT2V,
    WanVAE,
    get_sampling_sigmas,
    retrieve_timesteps,
    shard_model,
)
from .utils.vace_processor import VaceVideoProcessor


class WanVace(WanT2V):

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
    ):
        r"""
        Initializes the Wan text-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.t5_cpu = t5_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None)

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
        self.model = VaceWanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (
                usp_attn_forward,
                usp_dit_forward,
                usp_dit_forward_vace,
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            for block in self.model.vace_blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.model.forward_vace = types.MethodType(usp_dit_forward_vace,
                                                       self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt

        self.vid_proc = VaceVideoProcessor(
            downsample=tuple(
                [x * y for x, y in zip(config.vae_stride, self.patch_size)]),
            min_area=720 * 1280,
            max_area=720 * 1280,
            min_fps=config.sample_fps,
            max_fps=config.sample_fps,
            zero_start=True,
            seq_len=75600,
            keep_last=True)
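
    # vace_encode_frames: with masks, each source clip is split into an "inactive"
    # stream (content outside the mask) and a "reactive" stream (content under the
    # mask); both are VAE-encoded and stacked along the channel dim. Reference-image
    # latents, if provided, are prepended along the temporal dim.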
    def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
        vae = self.vae if vae is None else vae
        if ref_images is None:
            ref_images = [None] * len(frames)
        else:
            assert len(frames) == len(ref_images)

        if masks is None:
            latents = vae.encode(frames)
        else:
            masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
            inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
            reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
            inactive = vae.encode(inactive)
            reactive = vae.encode(reactive)
            latents = [
                torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)
            ]

        cat_latents = []
        for latent, refs in zip(latents, ref_images):
            if refs is not None:
                if masks is None:
                    ref_latent = vae.encode(refs)
                else:
                    ref_latent = vae.encode(refs)
                    ref_latent = [
                        torch.cat((u, torch.zeros_like(u)), dim=0)
                        for u in ref_latent
                    ]
                assert all([x.shape[1] == 1 for x in ref_latent])
                latent = torch.cat([*ref_latent, latent], dim=1)
            cat_latents.append(latent)
        return cat_latents
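
    # vace_encode_masks: brings masks to latent resolution. The spatial 8x8
    # neighbourhood is folded into the channel dim (space-to-depth), the temporal
    # axis is resized to the latent depth with nearest-exact interpolation, and
    # zero frames are prepended to cover any reference images.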
    def vace_encode_masks(self, masks, ref_images=None, vae_stride=None):
        vae_stride = self.vae_stride if vae_stride is None else vae_stride
        if ref_images is None:
            ref_images = [None] * len(masks)
        else:
            assert len(masks) == len(ref_images)

        result_masks = []
        for mask, refs in zip(masks, ref_images):
            c, depth, height, width = mask.shape
            new_depth = int((depth + 3) // vae_stride[0])
            height = 2 * (int(height) // (vae_stride[1] * 2))
            width = 2 * (int(width) // (vae_stride[2] * 2))

            # reshape
            mask = mask[0, :, :, :]
            mask = mask.view(depth, height, vae_stride[1], width,
                             vae_stride[1])  # depth, height, 8, width, 8
            mask = mask.permute(2, 4, 0, 1, 3)  # 8, 8, depth, height, width
            mask = mask.reshape(vae_stride[1] * vae_stride[2], depth, height,
                                width)  # 8*8, depth, height, width

            # interpolation
            mask = F.interpolate(
                mask.unsqueeze(0),
                size=(new_depth, height, width),
                mode='nearest-exact').squeeze(0)

            if refs is not None:
                length = len(refs)
                mask_pad = torch.zeros_like(mask[:, :length, :, :])
                mask = torch.cat((mask_pad, mask), dim=1)
            result_masks.append(mask)
        return result_masks
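
    # vace_latent: concatenates frame latents and encoded masks along the channel
    # dim; the result is the conditioning context consumed by the VACE blocks.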
    def vace_latent(self, z, m):
        return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
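
    # prepare_source: normalises inputs to the target resolution and sequence length
    # (only the 720x1280 and 480x832 areas are wired up). A missing video becomes a
    # zero tensor with an all-ones mask; reference images are letterboxed onto a
    # white canvas in [-1, 1] range.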
    def prepare_source(self, src_video, src_mask, src_ref_images, num_frames,
                       image_size, device):
        area = image_size[0] * image_size[1]
        self.vid_proc.set_area(area)
        if area == 720 * 1280:
            self.vid_proc.set_seq_len(75600)
        elif area == 480 * 832:
            self.vid_proc.set_seq_len(32760)
        else:
            raise NotImplementedError(
                f'image_size {image_size} is not supported')

        image_size = (image_size[1], image_size[0])
        image_sizes = []
        for i, (sub_src_video,
                sub_src_mask) in enumerate(zip(src_video, src_mask)):
            if sub_src_mask is not None and sub_src_video is not None:
                src_video[i], src_mask[
                    i], _, _, _ = self.vid_proc.load_video_pair(
                        sub_src_video, sub_src_mask)
                src_video[i] = src_video[i].to(device)
                src_mask[i] = src_mask[i].to(device)
                src_mask[i] = torch.clamp(
                    (src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
                image_sizes.append(src_video[i].shape[2:])
            elif sub_src_video is None:
                src_video[i] = torch.zeros(
                    (3, num_frames, image_size[0], image_size[1]),
                    device=device)
                src_mask[i] = torch.ones_like(src_video[i], device=device)
                image_sizes.append(image_size)
            else:
                src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video)
                src_video[i] = src_video[i].to(device)
                src_mask[i] = torch.ones_like(src_video[i], device=device)
                image_sizes.append(src_video[i].shape[2:])

        for i, ref_images in enumerate(src_ref_images):
            if ref_images is not None:
                image_size = image_sizes[i]
                for j, ref_img in enumerate(ref_images):
                    if ref_img is not None:
                        ref_img = Image.open(ref_img).convert("RGB")
                        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(
                            0.5).unsqueeze(1)
                        if ref_img.shape[-2:] != image_size:
                            canvas_height, canvas_width = image_size
                            ref_height, ref_width = ref_img.shape[-2:]
                            white_canvas = torch.ones(
                                (3, 1, canvas_height, canvas_width),
                                device=device)  # [-1, 1]
                            scale = min(canvas_height / ref_height,
                                        canvas_width / ref_width)
                            new_height = int(ref_height * scale)
                            new_width = int(ref_width * scale)
                            resized_image = F.interpolate(
                                ref_img.squeeze(1).unsqueeze(0),
                                size=(new_height, new_width),
                                mode='bilinear',
                                align_corners=False).squeeze(0).unsqueeze(1)
                            top = (canvas_height - new_height) // 2
                            left = (canvas_width - new_width) // 2
                            white_canvas[:, :, top:top + new_height,
                                         left:left + new_width] = resized_image
                            ref_img = white_canvas
                        src_ref_images[i][j] = ref_img.to(device)
        return src_video, src_mask, src_ref_images
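
    # decode_latent: drops the leading latent frames that correspond to reference
    # images before handing the latents to the VAE decoder.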
    def decode_latent(self, zs, ref_images=None, vae=None):
        vae = self.vae if vae is None else vae
        if ref_images is None:
            ref_images = [None] * len(zs)
        else:
            assert len(zs) == len(ref_images)

        trimed_zs = []
        for z, refs in zip(zs, ref_images):
            if refs is not None:
                z = z[:, len(refs):, :, :]
            trimed_zs.append(z)

        return vae.decode(trimed_zs)
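
    # Hypothetical usage sketch (not part of the original file; checkpoint path,
    # prompt and sizes are placeholders, and image_size is assumed to follow the
    # same (width, height) convention as `size` below):
    #   wan_vace = WanVace(config, checkpoint_dir="./Wan2.1-VACE-1.3B")
    #   src_video, src_mask, src_refs = wan_vace.prepare_source(
    #       [None], [None], [None], num_frames=81, image_size=(832, 480),
    #       device=wan_vace.device)
    #   video = wan_vace.generate("a prompt", src_video, src_mask, src_refs,
    #                             size=(832, 480))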
    def generate(self,
                 input_prompt,
                 input_frames,
                 input_masks,
                 input_ref_images,
                 size=(1280, 720),
                 frame_num=81,
                 context_scale=1.0,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=50,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        r"""
        Generates video frames from text prompt using diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation
            input_frames (list[torch.Tensor]):
                Source video tensors ([C, F, H, W]) providing the VACE conditioning content
            input_masks (list[torch.Tensor]):
                Mask tensors marking the regions to be regenerated
            input_ref_images (list[list[torch.Tensor]] or list[None]):
                Optional reference images prepended to the latent sequence
            size (tuple[`int`], *optional*, defaults to (1280,720)):
                Controls video resolution, (width,height).
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            context_scale (`float`, *optional*, defaults to 1.0):
                Strength of the VACE context injection
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 50):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults to 5.0):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed.
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM

        Returns:
            torch.Tensor:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames (81)
                - H: Frame height (from size)
                - W: Frame width (from size)
        """
        # preprocess
        # F = frame_num
        # target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
        #                 size[1] // self.vae_stride[1],
        #                 size[0] // self.vae_stride[2])
        #
        # seq_len = math.ceil((target_shape[2] * target_shape[3]) /
        #                     (self.patch_size[1] * self.patch_size[2]) *
        #                     target_shape[1] / self.sp_size) * self.sp_size

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt
        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)

        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        # vace context encode
        z0 = self.vace_encode_frames(
            input_frames, input_ref_images, masks=input_masks)
        m0 = self.vace_encode_masks(input_masks, input_ref_images)
        z = self.vace_latent(z0, m0)
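
        # The noise tensor uses half the channels of z0 because vace_encode_frames
        # stacks inactive and reactive latents along the channel dim, while the
        # denoised video latent only needs the base channel count. seq_len is rounded
        # up to a multiple of sp_size so sequence-parallel ranks get equal shares.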
        target_shape = list(z0[0].shape)
        target_shape[0] = int(target_shape[0] / 2)
        noise = [
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=torch.float32,
                device=self.device,
                generator=seed_g)
        ]
        seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                            (self.patch_size[1] * self.patch_size[2]) *
                            target_shape[1] / self.sp_size) * self.sp_size
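
        # no_sync() suppresses gradient synchronisation on FSDP-wrapped models; the
        # no-op context manager below is the fallback for plain (non-sharded) modules.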
        @contextmanager
        def noop_no_sync():
            yield

        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

        # evaluation mode
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            # sample videos
            latents = noise

            arg_c = {'context': context, 'seq_len': seq_len}
            arg_null = {'context': context_null, 'seq_len': seq_len}
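
            # Denoising loop with classifier-free guidance: two forward passes per
            # step (conditional and null text context, both with the VACE context),
            # combined via guide_scale before the scheduler update.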
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = latents
                timestep = [t]

                timestep = torch.stack(timestep)

                self.model.to(self.device)
                noise_pred_cond = self.model(
                    latent_model_input,
                    t=timestep,
                    vace_context=z,
                    vace_context_scale=context_scale,
                    **arg_c)[0]
                noise_pred_uncond = self.model(
                    latent_model_input,
                    t=timestep,
                    vace_context=z,
                    vace_context_scale=context_scale,
                    **arg_null)[0]

                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond)

                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latents[0].unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]
                latents = [temp_x0.squeeze(0)]

            x0 = latents
            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()
            if self.rank == 0:
                videos = self.decode_latent(x0, input_ref_images)

        del noise, latents
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None


class WanVaceMP(WanVace):

    def __init__(self,
                 config,
                 checkpoint_dir,
                 use_usp=False,
                 ulysses_size=None,
                 ring_size=None):
        self.config = config
        self.checkpoint_dir = checkpoint_dir
        self.use_usp = use_usp
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12345'
        os.environ['RANK'] = '0'
        os.environ['WORLD_SIZE'] = '1'
        self.in_q_list = None
        self.out_q = None
        self.inference_pids = None
        self.ulysses_size = ulysses_size
        self.ring_size = ring_size
        self.dynamic_load()

        # The parent process only does CPU-side preprocessing; the spawned workers
        # own the GPUs.
        self.device = 'cpu'

        self.vid_proc = VaceVideoProcessor(
            downsample=tuple(
                [x * y for x, y in zip(config.vae_stride, config.patch_size)]),
            min_area=480 * 832,
            max_area=480 * 832,
            min_fps=self.config.sample_fps,
            max_fps=self.config.sample_fps,
            zero_start=True,
            seq_len=32760,
            keep_last=True)
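
    # dynamic_load: spawns one worker process per GPU, gives each its own input
    # queue plus a shared output queue, and blocks until every worker has signalled
    # initialisation through its Event.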
    def dynamic_load(self):
        if hasattr(self, 'inference_pids') and self.inference_pids is not None:
            return
        gpu_infer = int(
            os.environ.get('LOCAL_WORLD_SIZE', torch.cuda.device_count()))
        pmi_rank = int(os.environ['RANK'])
        pmi_world_size = int(os.environ['WORLD_SIZE'])
        in_q_list = [
            torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)
        ]
        out_q = torch.multiprocessing.Manager().Queue()
        initialized_events = [
            torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)
        ]
        context = mp.spawn(
            self.mp_worker,
            nprocs=gpu_infer,
            args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q,
                  initialized_events, self),
            join=False)
        all_initialized = False
        while not all_initialized:
            all_initialized = all(
                event.is_set() for event in initialized_events)
            if not all_initialized:
                time.sleep(0.1)
        print('Inference model is initialized', flush=True)
        self.in_q_list = in_q_list
        self.out_q = out_q
        self.inference_pids = context.pids()
        self.initialized_events = initialized_events

    def transfer_data_to_cuda(self, data, device):
        if data is None:
            return None
        else:
            if isinstance(data, torch.Tensor):
                data = data.to(device)
            elif isinstance(data, list):
                data = [
                    self.transfer_data_to_cuda(subdata, device)
                    for subdata in data
                ]
            elif isinstance(data, dict):
                data = {
                    key: self.transfer_data_to_cuda(val, device)
                    for key, val in data.items()
                }
        return data
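
    # mp_worker: each spawned process pins one GPU, joins the NCCL process group,
    # sets up xfuser sequence parallelism, loads sharded copies of the text encoder,
    # VAE and VaceWanModel, then serves generation requests from its queue.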
    def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
                  out_q, initialized_events, work_env):
        try:
            world_size = pmi_world_size * gpu_infer
            rank = pmi_rank * gpu_infer + gpu
            print("world_size", world_size, "rank", rank, flush=True)

            torch.cuda.set_device(gpu)
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                rank=rank,
                world_size=world_size)

            from xfuser.core.distributed import (
                init_distributed_environment,
                initialize_model_parallel,
            )
            init_distributed_environment(
                rank=dist.get_rank(), world_size=dist.get_world_size())

            initialize_model_parallel(
                sequence_parallel_degree=dist.get_world_size(),
                ring_degree=self.ring_size or 1,
                ulysses_degree=self.ulysses_size or 1)

            num_train_timesteps = self.config.num_train_timesteps
            param_dtype = self.config.param_dtype
            shard_fn = partial(shard_model, device_id=gpu)
            text_encoder = T5EncoderModel(
                text_len=self.config.text_len,
                dtype=self.config.t5_dtype,
                device=torch.device('cpu'),
                checkpoint_path=os.path.join(self.checkpoint_dir,
                                             self.config.t5_checkpoint),
                tokenizer_path=os.path.join(self.checkpoint_dir,
                                            self.config.t5_tokenizer),
                shard_fn=shard_fn)
            text_encoder.model.to(gpu)
            vae_stride = self.config.vae_stride
            patch_size = self.config.patch_size
            vae = WanVAE(
                vae_pth=os.path.join(self.checkpoint_dir,
                                     self.config.vae_checkpoint),
                device=gpu)
            logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
            model = VaceWanModel.from_pretrained(self.checkpoint_dir)
            model.eval().requires_grad_(False)

            if self.use_usp:
                from xfuser.core.distributed import get_sequence_parallel_world_size

                from .distributed.xdit_context_parallel import (
                    usp_attn_forward,
                    usp_dit_forward,
                    usp_dit_forward_vace,
                )
                for block in model.blocks:
                    block.self_attn.forward = types.MethodType(
                        usp_attn_forward, block.self_attn)
                for block in model.vace_blocks:
                    block.self_attn.forward = types.MethodType(
                        usp_attn_forward, block.self_attn)
                model.forward = types.MethodType(usp_dit_forward, model)
                model.forward_vace = types.MethodType(usp_dit_forward_vace,
                                                      model)
                sp_size = get_sequence_parallel_world_size()
            else:
                sp_size = 1

            dist.barrier()
            model = shard_fn(model)
            sample_neg_prompt = self.config.sample_neg_prompt

            torch.cuda.empty_cache()
            event = initialized_events[gpu]
            in_q = in_q_list[gpu]
            event.set()
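
            # Worker loop: block on this worker's input queue, unpack one request,
            # move the conditioning tensors to the local GPU, run the same VACE
            # sampling path as WanVace.generate, and let rank 0 push the decoded
            # video to the shared output queue.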
            while True:
                item = in_q.get()
                input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, \
                    shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
                input_frames = self.transfer_data_to_cuda(input_frames, gpu)
                input_masks = self.transfer_data_to_cuda(input_masks, gpu)
                input_ref_images = self.transfer_data_to_cuda(
                    input_ref_images, gpu)

                if n_prompt == "":
                    n_prompt = sample_neg_prompt
                seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
                seed_g = torch.Generator(device=gpu)
                seed_g.manual_seed(seed)

                context = text_encoder([input_prompt], gpu)
                context_null = text_encoder([n_prompt], gpu)

                # vace context encode
                z0 = self.vace_encode_frames(
                    input_frames, input_ref_images, masks=input_masks, vae=vae)
                m0 = self.vace_encode_masks(
                    input_masks, input_ref_images, vae_stride=vae_stride)
                z = self.vace_latent(z0, m0)

                target_shape = list(z0[0].shape)
                target_shape[0] = int(target_shape[0] / 2)
                noise = [
                    torch.randn(
                        target_shape[0],
                        target_shape[1],
                        target_shape[2],
                        target_shape[3],
                        dtype=torch.float32,
                        device=gpu,
                        generator=seed_g)
                ]
                seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                                    (patch_size[1] * patch_size[2]) *
                                    target_shape[1] / sp_size) * sp_size

                @contextmanager
                def noop_no_sync():
                    yield

                no_sync = getattr(model, 'no_sync', noop_no_sync)

                # evaluation mode
                with amp.autocast(
                        dtype=param_dtype), torch.no_grad(), no_sync():

                    if sample_solver == 'unipc':
                        sample_scheduler = FlowUniPCMultistepScheduler(
                            num_train_timesteps=num_train_timesteps,
                            shift=1,
                            use_dynamic_shifting=False)
                        sample_scheduler.set_timesteps(
                            sampling_steps, device=gpu, shift=shift)
                        timesteps = sample_scheduler.timesteps
                    elif sample_solver == 'dpm++':
                        sample_scheduler = FlowDPMSolverMultistepScheduler(
                            num_train_timesteps=num_train_timesteps,
                            shift=1,
                            use_dynamic_shifting=False)
                        sampling_sigmas = get_sampling_sigmas(
                            sampling_steps, shift)
                        timesteps, _ = retrieve_timesteps(
                            sample_scheduler,
                            device=gpu,
                            sigmas=sampling_sigmas)
                    else:
                        raise NotImplementedError("Unsupported solver.")

                    # sample videos
                    latents = noise

                    arg_c = {'context': context, 'seq_len': seq_len}
                    arg_null = {'context': context_null, 'seq_len': seq_len}

                    for _, t in enumerate(tqdm(timesteps)):
                        latent_model_input = latents
                        timestep = [t]

                        timestep = torch.stack(timestep)

                        model.to(gpu)
                        noise_pred_cond = model(
                            latent_model_input,
                            t=timestep,
                            vace_context=z,
                            vace_context_scale=context_scale,
                            **arg_c)[0]
                        noise_pred_uncond = model(
                            latent_model_input,
                            t=timestep,
                            vace_context=z,
                            vace_context_scale=context_scale,
                            **arg_null)[0]

                        noise_pred = noise_pred_uncond + guide_scale * (
                            noise_pred_cond - noise_pred_uncond)

                        temp_x0 = sample_scheduler.step(
                            noise_pred.unsqueeze(0),
                            t,
                            latents[0].unsqueeze(0),
                            return_dict=False,
                            generator=seed_g)[0]
                        latents = [temp_x0.squeeze(0)]

                    torch.cuda.empty_cache()
                    x0 = latents
                    if rank == 0:
                        videos = self.decode_latent(
                            x0, input_ref_images, vae=vae)

                del noise, latents
                del sample_scheduler
                if offload_model:
                    gc.collect()
                    torch.cuda.synchronize()
                if dist.is_initialized():
                    dist.barrier()

                if rank == 0:
                    out_q.put(videos[0].cpu())

        except Exception as e:
            trace_info = traceback.format_exc()
            print(trace_info, flush=True)
            print(e, flush=True)

    def generate(self,
                 input_prompt,
                 input_frames,
                 input_masks,
                 input_ref_images,
                 size=(1280, 720),
                 frame_num=81,
                 context_scale=1.0,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=50,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        # Broadcast the same request to every worker (all ranks must step together
        # for sequence parallelism) and return the single decoded video from rank 0.
        input_data = (input_prompt, input_frames, input_masks, input_ref_images,
                      size, frame_num, context_scale, shift, sample_solver,
                      sampling_steps, guide_scale, n_prompt, seed,
                      offload_model)
        for in_q in self.in_q_list:
            in_q.put(input_data)
        value_output = self.out_q.get()

        return value_output