# NOTE: this file was scraped from a Hugging Face Space whose page reported
# "Runtime error"; the banner text has been converted to comments so the
# module remains valid Python.
| import argparse | |
| import datetime | |
| import inspect | |
| import os | |
| from omegaconf import OmegaConf | |
| import torch | |
| import diffusers | |
| from diffusers import AutoencoderKL, DDIMScheduler | |
| from tqdm.auto import tqdm | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from animatediff.models.unet import UNet3DConditionModel | |
| from animatediff.pipelines.pipeline_animation import AnimationPipeline | |
| from animatediff.utils.util import save_videos_grid | |
| from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint | |
| from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora | |
| from diffusers.utils.import_utils import is_xformers_available | |
| from einops import rearrange, repeat | |
| import csv, pdb, glob | |
| from safetensors import safe_open | |
| import math | |
| from pathlib import Path | |
def main(args):
    """Run AnimateDiff text-to-video sampling for every entry in the config.

    For each model config: loads the Stable Diffusion components
    (tokenizer / text encoder / VAE), inflates the 2D UNet into a 3D UNet,
    injects a motion-module checkpoint, optionally applies a personalized
    T2I checkpoint (``.ckpt`` or ``.safetensors``, with LoRA support), then
    samples one GIF per prompt. Outputs (per-prompt GIFs, a combined grid,
    and the resolved config with recorded seeds) are written under
    ``samples/<config-stem>-<timestamp>``.

    Args:
        args: argparse namespace with ``pretrained_model_path``,
            ``inference_config``, ``config``, and video geometry ``L``
            (frames), ``W``, ``H``.

    Raises:
        RuntimeError: if xformers is unavailable, or a motion-module
            checkpoint contains unexpected keys.
    """
    # Capture the call-time frame arguments; only used to record the motion
    # module's "global_step" metadata if present.
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"samples/{Path(args.config).stem}-{time_str}"
    os.makedirs(savedir)

    inference_config = OmegaConf.load(args.inference_config)
    config = OmegaConf.load(args.config)

    samples = []
    sample_idx = 0
    for model_idx, (config_key, model_config) in enumerate(list(config.items())):

        motion_modules = model_config.motion_module
        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
        for motion_module in motion_modules:

            ### >>> create validation pipeline >>> ###
            tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
            vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
            unet = UNet3DConditionModel.from_pretrained_2d(
                args.pretrained_model_path,
                subfolder="unet",
                unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
            )

            # xformers memory-efficient attention is mandatory for this
            # pipeline; fail loudly instead of `assert False`, which is
            # silently stripped under `python -O`.
            if is_xformers_available():
                unet.enable_xformers_memory_efficient_attention()
            else:
                raise RuntimeError("xformers is required for inference but is not available")

            pipeline = AnimationPipeline(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            ).to("cuda")

            # 1. unet ckpt
            # 1.1 motion module (temporal layers), loaded non-strictly on top
            # of the inflated 2D weights.
            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
            if "global_step" in motion_module_state_dict:
                func_args.update({"global_step": motion_module_state_dict["global_step"]})
            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
            if unexpected:
                # Unexpected keys mean the checkpoint does not match this
                # UNet; was a bare `assert`, which -O removes.
                raise RuntimeError(f"unexpected keys in motion module checkpoint: {unexpected}")

            # 1.2 T2I: optional personalized base checkpoint.
            if model_config.path != "":
                if model_config.path.endswith(".ckpt"):
                    # map_location added for consistency with the motion
                    # module load above (avoids GPU placement at load time).
                    state_dict = torch.load(model_config.path, map_location="cpu")
                    pipeline.unet.load_state_dict(state_dict)

                elif model_config.path.endswith(".safetensors"):
                    state_dict = {}
                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            state_dict[key] = f.get_tensor(key)

                    # Heuristic: a checkpoint whose every key mentions "lora"
                    # is a LoRA patch; it then needs the separate full base
                    # checkpoint (`model_config.base`) to convert from.
                    is_lora = all("lora" in k for k in state_dict.keys())
                    if not is_lora:
                        base_state_dict = state_dict
                    else:
                        base_state_dict = {}
                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
                            for key in f.keys():
                                base_state_dict[key] = f.get_tensor(key)

                    # vae
                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
                    # unet (non-strict: motion-module keys are not in the
                    # converted LDM checkpoint)
                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
                    # text_model
                    pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)

                    if is_lora:
                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)

            pipeline.to("cuda")
            ### <<< create validation pipeline <<< ###

            prompts = model_config.prompt
            # Broadcast a single negative prompt across all prompts.
            n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

            random_seeds = model_config.get("seed", [-1])
            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

            config[config_key].random_seed = []
            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):

                # manually set random seed for reproduction; -1 means "fresh
                # random seed", recorded back into the config either way.
                if random_seed != -1:
                    torch.manual_seed(random_seed)
                else:
                    torch.seed()
                config[config_key].random_seed.append(torch.initial_seed())

                print(f"current seed: {torch.initial_seed()}")
                print(f"sampling {prompt} ...")
                sample = pipeline(
                    prompt,
                    negative_prompt=n_prompt,
                    num_inference_steps=model_config.steps,
                    guidance_scale=model_config.guidance_scale,
                    width=args.W,
                    height=args.H,
                    video_length=args.L,
                ).videos
                samples.append(sample)

                # Filename slug: strip slashes, keep the first 10 words.
                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
                # Fixed: message now matches the actual saved path (the
                # original print omitted the `{sample_idx}-` prefix).
                print(f"save to {savedir}/sample/{sample_idx}-{prompt}.gif")

                sample_idx += 1

    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

    OmegaConf.save(config, f"{savedir}/config.yaml")
if __name__ == "__main__":
    # CLI entry point: model/config paths plus video geometry, then dispatch.
    parser = argparse.ArgumentParser()

    parser.add_argument("--pretrained_model_path", type=str,
                        default="models/StableDiffusion/stable-diffusion-v1-5")
    parser.add_argument("--inference_config", type=str,
                        default="configs/inference/inference.yaml")
    parser.add_argument("--config", type=str, required=True)

    # Video geometry: L frames at W x H pixels.
    parser.add_argument("--L", type=int, default=16)
    parser.add_argument("--W", type=int, default=512)
    parser.add_argument("--H", type=int, default=512)

    main(parser.parse_args())