import os
import sys
import argparse
import torch
import torchvision
from pipeline_videogen import VideoGenPipeline
from pipelines.pipeline_inversion import VideoGenInversionPipeline
from diffusers.schedulers import DDIMScheduler
from diffusers.models import AutoencoderKL
from diffusers.models import AutoencoderKLTemporalDecoder
from transformers import CLIPTokenizer, CLIPTextModel
from omegaconf import OmegaConf

# Make the parent directory of this script importable so that the project-level modules resolve.
sys.path.append(os.path.split(sys.path[0])[0])

from utils import find_model, dct_low_pass_filter, exchanged_mixed_dct_freq
from models import get_models
from models.unet import UNet3DConditionModel
from datasets import video_transforms
from torchvision import transforms
from copy import deepcopy
from PIL import Image
from einops import repeat
import imageio
import decord
import numpy as np
def prepare_image(path, vae, transform_video, device, dtype=torch.float16):
    with open(path, 'rb') as f:
        image = Image.open(f).convert('RGB')
    image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0).permute(0, 3, 1, 2)
    image, ori_h, ori_w, crops_coords_top, crops_coords_left = transform_video(image)
    # Encode into a VAE latent, then add a singleton frame dimension: [B, C, H, W] -> [B, C, 1, H, W].
    image = vae.encode(image.to(dtype=dtype, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor)
    image = image.unsqueeze(2)
    return image
def separation_content_motion(video_clip):
    """
    Separate content and motion in a given video.
    Args:
        video_clip: A given video clip, shape [B, C, F, H, W]
    Returns:
        base_frame: Base frame, shape [B, C, 1, H, W]
        motions: Motions relative to the base frame, shape [B, C, F-1, H, W]
    """
    # Select the first frame of each video in the batch as the base frame.
    base_frame = video_clip[:, :, :1, :, :]
    # Motion is the difference between each subsequent frame and the base frame.
    motions = video_clip[:, :, 1:, :, :] - base_frame
    return base_frame, motions
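
# Shape sketch (illustrative values only): for a latent clip of shape [1, 4, 16, 40, 64],
# separation_content_motion returns base_frame of shape [1, 4, 1, 40, 64] and
# motions of shape [1, 4, 15, 40, 64] (frame-wise differences from the first frame).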
class DecordInit(object):
    """Use Decord (https://github.com/dmlc/decord) to initialize the video reader."""

    def __init__(self, num_threads=1):
        self.num_threads = num_threads
        self.ctx = decord.cpu(0)

    def __call__(self, filename):
        """Create a Decord video reader for the given file.
        Args:
            filename (str): Path of the video file to read.
        """
        reader = decord.VideoReader(filename,
                                    ctx=self.ctx,
                                    num_threads=self.num_threads)
        return reader

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'num_threads={self.num_threads})')
        return repr_str
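
# Minimal usage sketch (mirrors how DecordInit is used in main() below):
#   reader = DecordInit()('./video_editing/test_03.mp4')
#   frames = reader.get_batch(np.linspace(0, 15, 16, dtype=int)).asnumpy()  # (16, H, W, 3) uint8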
def main(args):
    # torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16

    # unet = get_models(args).to(device, dtype=torch.float16)
    # state_dict = find_model(args.ckpt)
    # unet.load_state_dict(state_dict)
    unet = UNet3DConditionModel.from_pretrained(args.pretrained_model_path, subfolder="unet").to(device, dtype=torch.float16)

    # When DCT frequency mixing is enabled, the content VAE is loaded in float64
    # because the mixing step below runs in higher precision.
    if args.enable_vae_temporal_decoder:
        if args.use_dct:
            vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float64).to(device)
        else:
            vae_for_base_content = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
        vae = deepcopy(vae_for_base_content).to(dtype=dtype)
    else:
        vae_for_base_content = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device, dtype=torch.float64)
        vae = deepcopy(vae_for_base_content).to(dtype=dtype)

    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)

    # Set eval mode.
    unet.eval()
    vae.eval()
    text_encoder.eval()
    scheduler_inversion = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                                        subfolder="scheduler",
                                                        beta_start=args.beta_start,
                                                        beta_end=args.beta_end,
                                                        beta_schedule=args.beta_schedule)
    scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
                                              subfolder="scheduler",
                                              beta_start=args.beta_start,
                                              beta_end=args.beta_end,
                                              # beta_end=0.017,
                                              beta_schedule=args.beta_schedule)

    videogen_pipeline = VideoGenPipeline(vae=vae,
                                         text_encoder=text_encoder,
                                         tokenizer=tokenizer,
                                         scheduler=scheduler,
                                         unet=unet).to(device)
    videogen_pipeline_inversion = VideoGenInversionPipeline(vae=vae,
                                                            text_encoder=text_encoder,
                                                            tokenizer=tokenizer,
                                                            scheduler=scheduler_inversion,
                                                            unet=unet).to(device)
    # videogen_pipeline.enable_xformers_memory_efficient_attention()
    # videogen_pipeline.enable_vae_slicing()

    transform_video = video_transforms.Compose([
        video_transforms.ToTensorVideo(),
        video_transforms.SDXLCenterCrop((args.image_size[0], args.image_size[1])),  # center crop along the short edge, then resize
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    # video_path = './video_editing/A_man_walking_on_the_beach.mp4'
    # video_path = './video_editing/a_corgi_walking_in_the_park_at_sunrise_oil_painting_style.mp4'
    video_path = './video_editing/test_03.mp4'

    # Read the first 16 frames of the source video and scale pixel values to [-1, 1].
    video_reader = DecordInit()
    video = video_reader(video_path)
    frame_indice = np.linspace(0, 15, 16, dtype=int)
    video = torch.from_numpy(video.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
    video = video / 255.0
    video = video * 2.0 - 1.0

    # Encode the clip into latents of shape [B, C, F, H, W], then split it into a base (content)
    # frame and per-frame motion residuals.
    latents = vae.encode(video.to(dtype=torch.float16, device=device)).latent_dist.sample().mul_(vae.config.scaling_factor).unsqueeze(0).permute(0, 2, 1, 3, 4)
    base_content, motion_latents = separation_content_motion(latents)

    # image_path = "./video_editing/a_man_walking_in_the_park.png"
    # image_path = "./video_editing/a_cute_corgi_walking_in_the_park.png"
    image_path = "./video_editing/test_03.png"

    # Encode the edited first frame; use float64 when DCT frequency mixing is enabled.
    if args.use_dct:
        edit_content = prepare_image(image_path, vae_for_base_content, transform_video, device, dtype=torch.float64).to(device)
    else:
        edit_content = prepare_image(image_path, vae_for_base_content, transform_video, device, dtype=torch.float16).to(device)

    if not os.path.exists(args.save_img_path):
        os.makedirs(args.save_img_path)
    # prompt_inversion = 'a man walking on the beach'
    # prompt_inversion = 'a corgi walking in the park at sunrise, oil painting style'
    # prompt_inversion = 'A girl is playing the guitar in her room'
    prompt_inversion = 'A man is walking inside the church'

    # DDIM inversion of the motion latents, conditioned on the source prompt (guidance scale 1.0).
    latents = videogen_pipeline_inversion(prompt_inversion,
                                          latents=motion_latents,
                                          base_content=base_content,
                                          video_length=args.video_length,
                                          height=args.image_size[0],
                                          width=args.image_size[1],
                                          num_inference_steps=args.num_sampling_steps,
                                          guidance_scale=1.0,
                                          # guidance_scale=args.guidance_scale,
                                          motion_bucket_id=args.motion_bucket_id,
                                          output_type="latent").video

    # prompt = 'a man walking in the park'
    # prompt = 'a corgi walking in the park at sunrise, oil painting style'
    # prompt = 'A girl is playing the guitar in her room'
    prompt = 'A man is walking inside the church'
    if args.use_dct:
        print("Using DCT!")
        # Repeat the edited content latent along the frame axis to match the 15 motion frames.
        edit_content_repeat = repeat(edit_content, 'b c f h w -> b c (f r) h w', r=15).contiguous()

        # Define a 3D DCT low-pass filter from the edited content latent (cutoff set by `percentage`).
        freq_filter = dct_low_pass_filter(dct_coefficients=edit_content,
                                          percentage=0.23)

        noise = latents.to(dtype=torch.float64)

        # Add noise to the repeated edited content at a late diffusion timestep (t = 985).
        diffuse_timesteps = torch.full((1,), int(985))
        diffuse_timesteps = diffuse_timesteps.long()
        edit_content_noise = scheduler.add_noise(
            original_samples=edit_content_repeat.to(device),
            noise=noise,
            timesteps=diffuse_timesteps.to(device))

        # Mix frequencies in 3D DCT space: take the low-pass band from the noised edited content
        # and the remaining band from the inverted latents.
        latents = exchanged_mixed_dct_freq(noise=noise,
                                           base_content=edit_content_noise,
                                           LPF_3d=freq_filter).to(dtype=torch.float16)

    latents = latents.to(dtype=torch.float16)
    edit_content = edit_content.to(dtype=torch.float16)
    videos = videogen_pipeline(prompt,
                               latents=latents,
                               base_content=edit_content,
                               video_length=args.video_length,
                               height=args.image_size[0],
                               width=args.image_size[1],
                               num_inference_steps=args.num_sampling_steps,
                               # guidance_scale=1.0,
                               guidance_scale=args.guidance_scale,
                               motion_bucket_id=args.motion_bucket_id,
                               enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video

    imageio.mimwrite(args.save_img_path + prompt.replace(' ', '_') + '_%04d' % args.run_time + '-imageio.mp4',
                     videos[0], fps=8, quality=8)  # highest quality is 10, lowest is 0
    print('save path {}'.format(args.save_img_path))
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/sample.yaml")
    args = parser.parse_args()
    main(OmegaConf.load(args.config))
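
# Example invocation (the script filename below is a placeholder; the config path is the argparse default):
#   python sample_video_edit.py --config ./configs/sample.yaml
# The config is expected to provide the fields read above, e.g. pretrained_model_path, image_size,
# video_length, num_sampling_steps, guidance_scale, motion_bucket_id, use_dct,
# enable_vae_temporal_decoder, beta_start, beta_end, beta_schedule, save_img_path, run_time.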