import argparse, os, sys, glob, yaml, math, random
import datetime, time

import numpy as np
from omegaconf import OmegaConf
from collections import OrderedDict
from tqdm import trange, tqdm
from einops import rearrange, repeat
from functools import partial
import torch
from pytorch_lightning import seed_everything

from .funcs import load_model_checkpoint, load_image_batch, get_filelist, save_videos
from .funcs import batch_ddim_sampling
from .utils import instantiate_from_config


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=20230211, help="seed for seed_everything")
    parser.add_argument("--mode", default="base", type=str, help="inference mode: {'base', 'i2v'}")
    parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
    parser.add_argument("--config", type=str, help="config (yaml) path")
    parser.add_argument("--savefps", type=int, default=10, help="fps of the saved video")
    parser.add_argument("--n_samples", type=int, default=1, help="number of samples per prompt")
    parser.add_argument("--ddim_steps", type=int, default=50, help="steps of DDIM if positive, otherwise use DDPM")
    parser.add_argument("--ddim_eta", type=float, default=1.0,
                        help="eta for DDIM sampling (0.0 yields deterministic sampling)")
    parser.add_argument("--bs", type=int, default=1, help="batch size for inference")
    parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
    parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
    parser.add_argument("--frames", type=int, default=-1, help="number of frames to inference")
    parser.add_argument("--fps", type=int, default=24)
    parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0,
                        help="prompt classifier-free guidance scale")
    parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None,
                        help="temporal consistency guidance scale")
    ## for conditional i2v only
    # parser.add_argument("--cond_input", type=str, default=None, help="data dir of conditional input")
    return parser
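
# A minimal sketch of what an `arg_list` for this parser might look like
# (the flag values below are illustrative assumptions, not required settings):
#
#   arg_list = ["--ckpt_path", "./checkpoints/model.ckpt",   # hypothetical path
#               "--height", "320", "--width", "512",
#               "--frames", "16", "--fps", "24"]
#   args = get_parser().parse_args(args=arg_list)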


class VideoCrafterPipeline():
    def __init__(self, arg_list, device, rank: int = 0, gpu_num: int = 1):
        """
        Initialize the VideoCrafter pipeline. It always runs on a single GPU.

        Args:
            arg_list: Command-line style argument list consumed by `get_parser()`.
            device: Target device (not used here; the CUDA device is chosen from `rank`).
            rank: Rank of this process, used as the CUDA device index.
            gpu_num: Total number of GPUs used to shard prompts across processes.
        """
        parser = get_parser()
        self.args = parser.parse_args(args=arg_list)
        self.gpu_no, self.gpu_num = rank, gpu_num
        _dict = {
            'model': {
                'target': 'lvdm.models.ddpm3d.LatentDiffusion',
                'params': {
                    'linear_start': 0.00085,
                    'linear_end': 0.012,
                    'num_timesteps_cond': 1,
                    'timesteps': 1000,
                    'first_stage_key': 'video',
                    'cond_stage_key': 'caption',
                    'cond_stage_trainable': False,
                    'conditioning_key': 'crossattn',
                    'image_size': [40, 64],
                    'channels': 4,
                    'scale_by_std': False,
                    'scale_factor': 0.18215,
                    'use_ema': False,
                    'uncond_type': 'empty_seq',
                    'use_scale': True,
                    'scale_b': 0.7,
                    'unet_config': {
                        'target': 'lvdm.modules.networks.openaimodel3d.UNetModel',
                        'params': {
                            'in_channels': 4,
                            'out_channels': 4,
                            'model_channels': 320,
                            'attention_resolutions': [4, 2, 1],
                            'num_res_blocks': 2,
                            'channel_mult': [1, 2, 4, 4],
                            'num_head_channels': 64,
                            'transformer_depth': 1,
                            'context_dim': 1024,
                            'use_linear': True,
                            'use_checkpoint': True,
                            'temporal_conv': True,
                            'temporal_attention': True,
                            'temporal_selfatt_only': True,
                            'use_relative_position': False,
                            'use_causal_attention': False,
                            'temporal_length': 16,
                            'addition_attention': True,
                            'fps_cond': True,
                        },
                    },
                    'first_stage_config': {
                        'target': 'lvdm.models.autoencoder.AutoencoderKL',
                        'params': {
                            'embed_dim': 4,
                            'monitor': 'val/rec_loss',
                            'ddconfig': {
                                'double_z': True,
                                'z_channels': 4,
                                'resolution': 512,
                                'in_channels': 3,
                                'out_ch': 3,
                                'ch': 128,
                                'ch_mult': [1, 2, 4, 4],
                                'num_res_blocks': 2,
                                'attn_resolutions': [],
                                'dropout': 0.0,
                            },
                            'lossconfig': {'target': 'torch.nn.Identity'},
                        },
                    },
                    'cond_stage_config': {
                        'target': 'lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder',
                        'params': {'freeze': True, 'layer': 'penultimate'},
                    },
                },
            },
        }
        config = OmegaConf.create(_dict)
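        # Note (derived from the VAE config above): ch_mult has three
        # downsampling stages, so latents are 8x smaller per spatial dim;
        # 'image_size' [40, 64] therefore corresponds to 320x512 pixels.
        # `run_inference` recomputes the latent shape from its own height/width.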
        # config = OmegaConf.load(self.args.config)
        # data_config = config.pop("data", OmegaConf.create())
        model_config = config.pop("model", OmegaConf.create())
        model = instantiate_from_config(model_config)
        model = model.cuda(self.gpu_no)
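        # The model is placed on the CUDA device indexed by `rank`;
        # the `device` argument passed to __init__ is not consulted here.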
print("About to load model") | |
assert os.path.exists(self.args.ckpt_path), f"Error: checkpoint [{self.args.ckpt_path}] Not Found!" | |
self.model = load_model_checkpoint(model, self.args.ckpt_path) | |
self.model.eval() | |

    def run_inference(self, prompt, video_length, height, width, **kwargs):
        """
        Generate a video from the provided text prompt.
        See https://github.com/AILab-CVC/VideoCrafter

        Args:
            prompt: The text prompt to condition on.
            video_length: The length (number of frames) of the generated video.
            height: The height of each video frame, in pixels.
            width: The width of each video frame, in pixels.
            **kwargs: Extra options forwarded to `batch_ddim_sampling`.

        Returns:
            The generated videos as a tensor of shape
            (batch_size, n_samples, channels, frames, height, width).
        """
        ## step 1: model config
        ## -----------------------------------------------------------------
        ## sample shape: check the requested size, not the (unused) parser defaults
        assert (height % 16 == 0) and (width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
        ## latent noise shape
        h, w = height // 8, width // 8
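        # The AutoencoderKL configured above downsamples by a factor of 8 per
        # spatial dimension, so e.g. a 320x512 frame maps to a 40x64 latent.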
        frames = video_length
        channels = self.model.channels

        ## step 2: load data
        ## -----------------------------------------------------------------
        prompt_list = [prompt]
        num_samples = len(prompt_list)
        # filename_list = [f"{id + 1:04d}" for id in range(num_samples)]

        gpu_num = self.gpu_num
        gpu_no = self.gpu_no
        samples_split = num_samples // gpu_num
        residual_tail = num_samples % gpu_num
        print(f'[rank:{gpu_no}] {samples_split}/{num_samples} samples loaded.')
        indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1)))
        if gpu_no == 0 and residual_tail != 0:
            indices = indices + list(range(num_samples - residual_tail, num_samples))
        prompt_list_rank = [prompt_list[i] for i in indices]
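        # Sharding example (illustrative): with num_samples=5 and gpu_num=2,
        # rank 0 takes indices [0, 1] plus the residual [4]; rank 1 takes [2, 3].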

        # # conditional input
        # if self.args.mode == "i2v":
        #     ## each video or frames dir per prompt
        #     cond_inputs = get_filelist(self.args.cond_input, ext='[mpj][pn][4gj]')
        #     assert len(cond_inputs) == num_samples, \
        #         f"Error: conditional input ({len(cond_inputs)}) does NOT match prompt ({num_samples})!"
        #     filename_list = [f"{os.path.split(cond_inputs[id])[-1][:-4]}" for id in range(num_samples)]
        #     cond_inputs_rank = [cond_inputs[i] for i in indices]
        #     filename_list_rank = [filename_list[i] for i in indices]

        ## step 3: run over samples
        ## -----------------------------------------------------------------
        # start = time.time()
        n_rounds = len(prompt_list_rank) // self.args.bs
        n_rounds = n_rounds + 1 if len(prompt_list_rank) % self.args.bs != 0 else n_rounds
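        # Equivalent to math.ceil(len(prompt_list_rank) / self.args.bs),
        # e.g. 5 prompts with bs=2 yield 3 rounds.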
        for idx in range(0, n_rounds):
            print(f'[rank:{gpu_no}] batch-{idx + 1} ({self.args.bs})x{self.args.n_samples} ...')
            idx_s = idx * self.args.bs
            idx_e = min(idx_s + self.args.bs, len(prompt_list_rank))
            batch_size = idx_e - idx_s
            # filenames = filename_list_rank[idx_s:idx_e]
            noise_shape = [batch_size, channels, frames, h, w]
            fps = torch.tensor([self.args.fps] * batch_size).to(self.model.device).long()

            prompts = prompt_list_rank[idx_s:idx_e]
            if isinstance(prompts, str):
                prompts = [prompts]
            # prompts = batch_size * [""]
            text_emb = self.model.get_learned_conditioning(prompts)
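
            # Conditioning: the frozen OpenCLIP text encoder provides the
            # cross-attention context; the fps tensor is the extra scalar
            # condition enabled by `fps_cond: True` in the UNet config above.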
            if self.args.mode == 'base':
                cond = {"c_crossattn": [text_emb], "fps": fps}
            # elif self.args.mode == 'i2v':
            #     # cond_images = torch.zeros(noise_shape[0], 3, 224, 224).to(model.device)
            #     cond_images = load_image_batch(cond_inputs_rank[idx_s:idx_e], (self.args.height, self.args.width))
            #     cond_images = cond_images.to(self.model.device)
            #     img_emb = self.model.get_image_embeds(cond_images)
            #     imtext_cond = torch.cat([text_emb, img_emb], dim=1)
            #     cond = {"c_crossattn": [imtext_cond], "fps": fps}
            else:
                raise NotImplementedError

            ## inference
            batch_samples = batch_ddim_sampling(self.model, cond, noise_shape, self.args.n_samples,
                                                self.args.ddim_steps, self.args.ddim_eta,
                                                self.args.unconditional_guidance_scale, **kwargs)
        return batch_samples
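
# A minimal usage sketch (illustrative assumptions: the checkpoint path below
# is hypothetical, and the module must be imported as part of its package for
# the relative imports above to resolve, so this is left as a comment):
#
#   seed_everything(20230211)
#   arg_list = ["--ckpt_path", "./checkpoints/videocrafter/model.ckpt"]  # hypothetical
#   pipe = VideoCrafterPipeline(arg_list, device="cuda", rank=0, gpu_num=1)
#   video = pipe.run_inference("a corgi running on the beach",
#                              video_length=16, height=320, width=512)
#   print(video.shape)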