# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
'''
/*
*Copyright (c) 2021, Alibaba Group;
*Licensed under the Apache License, Version 2.0 (the "License");
*you may not use this file except in compliance with the License.
*You may obtain a copy of the License at
*    http://www.apache.org/licenses/LICENSE-2.0
*Unless required by applicable law or agreed to in writing, software
*distributed under the License is distributed on an "AS IS" BASIS,
*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*See the License for the specific language governing permissions and
*limitations under the License.
*/
'''
import os
import re
import os.path as osp
import sys
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4]))
import json
import math
import random
import logging
from copy import copy
from importlib import reload

import cv2
import torch
import pynvml
import numpy as np
from PIL import Image
from tqdm import tqdm
from einops import rearrange
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.nn.parallel import DistributedDataParallel

import utils.transforms as data
from ..modules.config import cfg
from utils.seed import setup_seed
from utils.multi_port import find_free_port
from utils.assign_cfg import assign_signle_cfg
from utils.distributed import generalized_all_gather, all_reduce
from utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col
from tools.modules.autoencoder import get_first_stage_encoding
from utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION


def inference_unianimate_entrance(cfg_update, **kwargs):
    """Merge `cfg_update` into the global config and launch the inference worker(s)."""
    for k, v in cfg_update.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = find_free_port()
    cfg.pmi_rank = int(os.getenv('RANK', 0))
    cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1))

    if cfg.debug:
        cfg.gpus_per_machine = 1
        cfg.world_size = 1
    else:
        cfg.gpus_per_machine = torch.cuda.device_count()
        cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine

    if cfg.world_size == 1:
        worker(0, cfg, cfg_update)
    else:
        mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update))
    return cfg
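
# A minimal usage sketch (illustrative only; the paths and values below are hypothetical,
# and the remaining settings are expected to come from the default `cfg` imported above):
# the entrance merges `cfg_update` into `cfg` and runs one worker per available GPU.
#
#     cfg_update = {
#         'seed': 11,
#         'max_frames': 32,
#         'round': 1,
#         # each entry: [frame_interval, reference image path, pose sequence directory]
#         'test_list_path': [[1, 'data/images/ref.jpg', 'data/saved_pose/ref']],
#     }
#     inference_unianimate_entrance(cfg_update)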


def make_masked_images(imgs, masks):
    masked_imgs = []
    for i, mask in enumerate(masks):
        # concatenate the masked image with its mask along the channel dimension
        masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1))
    return torch.stack(masked_imgs, dim=0)
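
# Note: `make_masked_images` is not called elsewhere in this file. Assuming each
# `imgs[i]` is an [F, C, H, W] frame tensor and each mask is [F, 1, H, W], the
# concatenation along dim=1 yields [F, C+1, H, W] per sample, i.e. the masked
# frames with the mask appended as an extra channel.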


def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose,
                      max_frames=32, frame_interval=1, resolution=[512, 768],
                      get_first_frame=True, vit_resolution=[224, 224]):
    for _ in range(5):
        try:
            dwpose_all = {}
            frames_all = {}
            for ii_index in sorted(os.listdir(pose_file_path)):
                if ii_index != "ref_pose.jpg":
                    dwpose_all[ii_index] = Image.open(os.path.join(pose_file_path, ii_index))
                    frames_all[ii_index] = Image.fromarray(cv2.cvtColor(cv2.imread(ref_image_path), cv2.COLOR_BGR2RGB))
            pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg"))

            # Sample max_frames poses for video generation
            stride = frame_interval
            total_frame_num = len(frames_all)
            cover_frame_num = (stride * (max_frames - 1) + 1)
            if total_frame_num < cover_frame_num:
                print(f'total_frame_num ({total_frame_num}) is smaller than cover_frame_num ({cover_frame_num}); the sampling stride is adjusted')
                start_frame = 0
                end_frame = total_frame_num
                stride = max((total_frame_num - 1) // (max_frames - 1), 1)
                end_frame = stride * max_frames
            else:
                start_frame = 0
                end_frame = start_frame + cover_frame_num

            frame_list = []
            dwpose_list = []
            random_ref_frame = frames_all[list(frames_all.keys())[0]]
            if random_ref_frame.mode != 'RGB':
                random_ref_frame = random_ref_frame.convert('RGB')
            random_ref_dwpose = pose_ref
            if random_ref_dwpose.mode != 'RGB':
                random_ref_dwpose = random_ref_dwpose.convert('RGB')

            for i_index in range(start_frame, end_frame, stride):
                if i_index < len(frames_all):  # check that the index is within bounds
                    i_key = list(frames_all.keys())[i_index]
                    i_frame = frames_all[i_key]
                    if i_frame.mode != 'RGB':
                        i_frame = i_frame.convert('RGB')
                    i_dwpose = dwpose_all[i_key]
                    if i_dwpose.mode != 'RGB':
                        i_dwpose = i_dwpose.convert('RGB')
                    frame_list.append(i_frame)
                    dwpose_list.append(i_dwpose)

            if frame_list:
                middle_index = 0
                ref_frame = frame_list[middle_index]
                vit_frame = vit_transforms(ref_frame)
                random_ref_frame_tmp = train_trans_pose(random_ref_frame)
                random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose)
                misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0)
                video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0)
                dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0)

                video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
                dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
                misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
                random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
                random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])

                video_data[:len(frame_list), ...] = video_data_tmp
                misc_data[:len(frame_list), ...] = misc_data_tmp
                dwpose_data[:len(frame_list), ...] = dwpose_data_tmp
                random_ref_frame_data[:, ...] = random_ref_frame_tmp
                random_ref_dwpose_data[:, ...] = random_ref_dwpose_tmp

                return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data
        except Exception as e:
            logging.info(f'Error reading video frame: {e}')
            continue
    return None, None, None, None, None, None
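
# On success, load_video_frames returns (per the allocations above):
#   vit_frame                -- the reference frame after vit_transforms (3 x 224 x 224 for the default vit_resolution)
#   video_data, misc_data    -- [max_frames, 3, resolution[1], resolution[0]] frame tensors
#   dwpose_data              -- [max_frames, 3, resolution[1], resolution[0]] pose maps
#   random_ref_frame_data, random_ref_dwpose_data -- the reference frame/pose repeated over max_frames
# If all five read attempts fail, six None values are returned instead.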


def worker(gpu, cfg, cfg_update):
    '''
    Inference worker for each gpu
    '''
    for k, v in cfg_update.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    cfg.gpu = gpu
    cfg.seed = int(cfg.seed)
    cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu
    setup_seed(cfg.seed + cfg.rank)

    if not cfg.debug:
        torch.cuda.set_device(gpu)
        torch.backends.cudnn.benchmark = True
        if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
            torch.backends.cudnn.benchmark = False
        dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank)

    # [Log] Save logging and make log dir
    log_dir = generalized_all_gather(cfg.log_dir)[0]
    inf_name = osp.basename(cfg.cfg_file).split('.')[0]
    test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1]

    cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name))
    os.makedirs(cfg.log_dir, exist_ok=True)
    log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank))
    cfg.log_file = log_file
    reload(logging)
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] %(levelname)s: %(message)s',
        handlers=[
            logging.FileHandler(filename=log_file),
            logging.StreamHandler(stream=sys.stdout)])
    logging.info(cfg)
    logging.info(f"Running UniAnimate inference on gpu {gpu}")

    # [Diffusion]
    diffusion = DIFFUSION.build(cfg.Diffusion)

    # [Data] Data Transform
    train_trans = data.Compose([
        data.Resize(cfg.resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.mean, std=cfg.std)])

    train_trans_pose = data.Compose([
        data.Resize(cfg.resolution),
        data.ToTensor()])

    vit_transforms = T.Compose([
        data.Resize(cfg.vit_resolution),
        T.ToTensor(),
        T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])

    # [Model] embedder
    clip_encoder = EMBEDDER.build(cfg.embedder)
    clip_encoder.model.to(gpu)
    with torch.no_grad():
        _, _, zero_y = clip_encoder(text="")

    # [Model] autoencoder
    autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
    autoencoder.eval()  # freeze
    for param in autoencoder.parameters():
        param.requires_grad = False
    autoencoder.cuda()

    # [Model] UNet
    if "config" in cfg.UNet:
        cfg.UNet["config"] = cfg
    cfg.UNet["zero_y"] = zero_y
    model = MODEL.build(cfg.UNet)
    state_dict = torch.load(cfg.test_model, map_location='cpu')
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    if 'step' in state_dict:
        resume_step = state_dict['step']
    else:
        resume_step = 0
    status = model.load_state_dict(state_dict, strict=True)
    logging.info('Load model from {} with status {}'.format(cfg.test_model, status))
    model = model.to(gpu)
    model.eval()
    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
        model.to(torch.float16)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
    torch.cuda.empty_cache()

    test_list = cfg.test_list_path
    num_videos = len(test_list)
    logging.info(f'There are {num_videos} videos; each will be sampled {cfg.round} time(s)')
    # test_list = [item for item in test_list for _ in range(cfg.round)]
    test_list = [item for _ in range(cfg.round) for item in test_list]
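
    # Each test_list entry is unpacked below as [frame_interval, reference image path,
    # pose sequence directory]; the list is repeated cfg.round times, and idx // num_videos
    # offsets the seed so each round samples with a different seed.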
    for idx, file_path in enumerate(test_list):
        cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2]
        manual_seed = int(cfg.seed + cfg.rank + idx // num_videos)
        setup_seed(manual_seed)
        logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...")
        vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data = load_video_frames(
            ref_image_key, pose_seq_key, train_trans, vit_transforms, train_trans_pose,
            max_frames=cfg.max_frames, frame_interval=cfg.frame_interval, resolution=cfg.resolution)
        misc_data = misc_data.unsqueeze(0).to(gpu)
        vit_frame = vit_frame.unsqueeze(0).to(gpu)
        dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
        random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
        random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)

        ### save for visualization
        misc_backups = copy(misc_data)
        frames_num = misc_data.shape[1]
        misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
        mv_data_video = []

        ### local image (first frame)
        image_local = []
        if 'local_image' in cfg.video_compositions:
            frames_num = misc_data.shape[1]
            bs_vd_local = misc_data.shape[0]
            image_local = misc_data[:, :1].clone().repeat(1, frames_num, 1, 1, 1)
            image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b=bs_vd_local)
            image_local = rearrange(image_local, 'b f c h w -> b c f h w', b=bs_vd_local)
            if hasattr(cfg, "latent_local_image") and cfg.latent_local_image:
                with torch.no_grad():
                    temporal_length = frames_num
                    encoder_posterior = autoencoder.encode(video_data[:, 0])
                    local_image_data = get_first_stage_encoding(encoder_posterior).detach()
                    image_local = local_image_data.unsqueeze(1).repeat(1, temporal_length, 1, 1, 1)  # [10, 16, 4, 64, 40]

        ### encode the video_data
        bs_vd = misc_data.shape[0]
        misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w')
        misc_data_list = torch.chunk(misc_data, misc_data.shape[0] // cfg.chunk_size, dim=0)

        with torch.no_grad():
            random_ref_frame = []
            if 'randomref' in cfg.video_compositions:
                random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')
                if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref:
                    temporal_length = random_ref_frame_data.shape[1]
                    encoder_posterior = autoencoder.encode(random_ref_frame_data[:, 0].sub(0.5).div_(0.5))
                    random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach()
                    random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1, temporal_length, 1, 1, 1)  # [10, 16, 4, 64, 40]
                random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w')

            if 'dwpose' in cfg.video_compositions:
                bs_vd_local = dwpose_data.shape[0]
                dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b=bs_vd_local)
                if 'randomref_pose' in cfg.video_compositions:
                    dwpose_data = torch.cat([random_ref_dwpose_data[:, :1], dwpose_data], dim=1)
                dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b=bs_vd_local)

            y_visual = []
            if 'image' in cfg.video_compositions:
                with torch.no_grad():
                    vit_frame = vit_frame.squeeze(1)
                    y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1)  # [60, 1024]
                    y_visual0 = y_visual.clone()

            with amp.autocast(enabled=True):
                pynvml.nvmlInit()
                handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
                cur_seed = torch.initial_seed()
                logging.info(f"Current seed {cur_seed} ...")

                noise = torch.randn([1, 4, cfg.max_frames, int(cfg.resolution[1] / cfg.scale), int(cfg.resolution[0] / cfg.scale)])
                noise = noise.to(gpu)
                if hasattr(cfg.Diffusion, "noise_strength"):
                    b, c, f, _, _ = noise.shape
                    offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device)
                    noise = noise + cfg.Diffusion.noise_strength * offset_noise
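                    # The offset above is one random value per (batch, channel, frame),
                    # broadcast over the spatial dimensions: the commonly used "offset noise"
                    # trick, scaled by cfg.Diffusion.noise_strength before being mixed in.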

                # add a noise prior: diffuse the encoded reference frame to a late timestep
                # (noise_prior_value, default 949) so sampling starts from reference-aware noise
                noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 949), noise=noise)

                # construct model inputs (CFG)
                full_model_kwargs = [{
                    'y': None,
                    "local_image": None if len(image_local) == 0 else image_local[:],
                    'image': None if len(y_visual) == 0 else y_visual0[:],
                    'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:],
                    'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:],
                },
                {
                    'y': None,
                    "local_image": None,
                    'image': None,
                    'randomref': None,
                    'dwpose': None,
                }]
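
                # The two dicts above form the [conditional, unconditional] pair for
                # classifier-free guidance: the sampler runs the model on both and blends
                # the predictions according to cfg.guide_scale.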

                # for visualization
                full_model_kwargs_vis = [{
                    'y': None,
                    "local_image": None if len(image_local) == 0 else image_local_clone[:],
                    'image': None,
                    'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:],
                    'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3],
                },
                {
                    'y': None,
                    "local_image": None,
                    'image': None,
                    'randomref': None,
                    'dwpose': None,
                }]

                partial_keys = [
                    ['image', 'randomref', "dwpose"],
                ]
                if hasattr(cfg, "partial_keys") and cfg.partial_keys:
                    partial_keys = cfg.partial_keys

                for partial_keys_one in partial_keys:
                    model_kwargs_one = prepare_model_kwargs(partial_keys=partial_keys_one,
                                                            full_model_kwargs=full_model_kwargs,
                                                            use_fps_condition=cfg.use_fps_condition)
                    model_kwargs_one_vis = prepare_model_kwargs(partial_keys=partial_keys_one,
                                                                full_model_kwargs=full_model_kwargs_vis,
                                                                use_fps_condition=cfg.use_fps_condition)
                    noise_one = noise

                    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
                        # offload the CLIP encoder and VAE to CPU to free GPU memory for the UNet
                        clip_encoder.cpu()
                        autoencoder.cpu()
                        torch.cuda.empty_cache()

                    video_data = diffusion.ddim_sample_loop(
                        noise=noise_one,
                        model=model.eval(),
                        model_kwargs=model_kwargs_one,
                        guide_scale=cfg.guide_scale,
                        ddim_timesteps=cfg.ddim_timesteps,
                        eta=0.0)

                    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
                        # move the autoencoder and CLIP encoder back to GPU before decoding
                        clip_encoder.cuda()
                        autoencoder.cuda()
                    video_data = 1. / cfg.scale_factor * video_data
                    video_data = rearrange(video_data, 'b c f h w -> (b f) c h w')
                    chunk_size = min(cfg.decoder_bs, video_data.shape[0])
                    video_data_list = torch.chunk(video_data, video_data.shape[0] // chunk_size, dim=0)
                    decode_data = []
                    for vd_data in video_data_list:
                        gen_frames = autoencoder.decode(vd_data)
                        decode_data.append(gen_frames)
                    video_data = torch.cat(decode_data, dim=0)
                    video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b=cfg.batch_size).float()

                    text_size = cfg.resolution[-1]
                    cap_name = re.sub(r'[^\w\s]', '', ref_image_key.split("/")[-1].split('.')[0])  # .replace(' ', '_')
                    name = f'seed_{cur_seed}'
                    for ii in partial_keys_one:
                        name = name + "_" + ii
                    file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{idx:02d}_{name}_{cap_name}_{cfg.resolution[1]}x{cfg.resolution[0]}.mp4'
                    local_path = os.path.join(cfg.log_dir, f'{file_name}')
                    os.makedirs(os.path.dirname(local_path), exist_ok=True)
                    captions = "human"
                    del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]]
                    del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]]

                    save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_data.cpu(), model_kwargs_one_vis, misc_backups,
                                                                           cfg.mean, cfg.std, nrow=1, save_fps=cfg.save_fps)
                    # try:
                    #     save_t2vhigen_video_safe(local_path, video_data.cpu(), captions, cfg.mean, cfg.std, text_size)
                    #     logging.info('Save video to dir %s:' % (local_path))
                    # except Exception as e:
                    #     logging.info(f'Step: save text or video error with {e}')

    logging.info('Congratulations! The inference is completed!')
    # synchronize to finish some processes
    if not cfg.debug:
        torch.cuda.synchronize()
        dist.barrier()


def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False):
    if use_fps_condition is True:
        partial_keys.append('fps')

    partial_model_kwargs = [{}, {}]
    for partial_key in partial_keys:
        partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key]
        partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key]

    return partial_model_kwargs
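

# A minimal illustration of prepare_model_kwargs (the tensors are hypothetical
# placeholders): it copies only the requested condition keys from the
# [conditional, unconditional] pair; with use_fps_condition=True it also appends
# 'fps' to the caller's partial_keys list before copying.
#
#     full = [{'y': None, 'local_image': None, 'image': img_emb, 'dwpose': pose, 'randomref': ref},
#             {'y': None, 'local_image': None, 'image': None, 'dwpose': None, 'randomref': None}]
#     kwargs = prepare_model_kwargs(['image', 'randomref', 'dwpose'], full, use_fps_condition=False)
#     # kwargs[0] -> {'image': img_emb, 'randomref': ref, 'dwpose': pose}
#     # kwargs[1] -> {'image': None, 'randomref': None, 'dwpose': None}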