# NOTE: the lines below are page residue from the scrape this file came from
# (Hugging Face Spaces status, author, commit); kept as comments so the file parses.
# Spaces: Sleeping
# Sleeping
# Evgeny Zhukov
# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
# 2ba4412
import os
import os.path as osp
import shutil
import sys

# Make the project root importable before any local imports.
sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2]))

import imageio
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont

import utils.transforms as data
from tools import *
from tools.modules.config import cfg
from utils.config import Config as pConfig
from utils.registry_class import ENGINE, DATASETS
def test_video_dataset():
    """Smoke-test the configured video dataset.

    Builds the dataset from ``cfg.vid_dataset`` (after merging the CLI config
    loaded by ``pConfig(load=True)`` into the global ``cfg``), then for every
    sample writes a GIF to ``workspace/test_data/datasets/<type>/<exp_name>``
    where each frame is ``[reference frame | video frame | caption image]``
    concatenated horizontally, so the pipeline can be inspected visually.
    """
    # Merge the command-line/config-file overrides into the global cfg.
    cfg_update = pConfig(load=True)
    for k, v in cfg_update.cfg_dict.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
    save_dir = os.path.join('workspace', 'test_data/datasets', cfg.vid_dataset['type'], exp_name)
    # shutil.rmtree instead of `os.system('rm -rf %s')`: no shell invocation,
    # no breakage on paths with spaces, and portable.
    shutil.rmtree(save_dir, ignore_errors=True)
    os.makedirs(save_dir, exist_ok=True)

    train_trans = data.Compose([
        data.CenterCropWide(size=cfg.resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.mean, std=cfg.std)])
    vit_trans = T.Compose([
        data.CenterCropWide(cfg.vit_resolution),
        T.ToTensor(),
        T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])

    # Broadcast shapes for de-normalization: (1, c, 1, 1) over (f, c, h, w)
    # video tensors, (c, 1, 1) over (c, h, w) single images.
    video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1)
    video_std = torch.tensor(cfg.std).view(1, -1, 1, 1)
    img_mean = torch.tensor(cfg.mean).view(-1, 1, 1)
    img_std = torch.tensor(cfg.std).view(-1, 1, 1)

    # Caption rendering: square text panel matching the frame height; `nc`
    # characters per line, scaled from 38 chars at a 256-px panel.
    txt_size = cfg.resolution[1]
    nc = int(38 * (txt_size / 256))
    # NOTE(review): test_vit_image loads the same font from 'artist/font/...';
    # one of the two paths is presumably stale — confirm which exists.
    font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13)

    dataset = DATASETS.build(cfg.vid_dataset, sample_fps=cfg.sample_fps[0],
                             transforms=train_trans, vit_transforms=vit_trans)
    print('There are %d videos' % (len(dataset)))
    for idx, item in enumerate(dataset):
        ref_frame, vit_frame, video_data, caption, video_key = item

        # Undo normalization (x_norm * std + mean) and convert to uint8 HWC frames.
        video_data = video_data.mul_(video_std).add_(video_mean)
        video_data.clamp_(0, 1)
        video_data = video_data.permute(0, 2, 3, 1)
        video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]

        # Single reference image.
        # BUG FIX: de-normalization is x * std + mean; the original multiplied
        # by the mean and added the std (mul_(img_mean).add_(img_std)), unlike
        # every other de-normalization in this file.
        ref_frame = ref_frame.mul_(img_std).add_(img_mean)
        ref_frame.clamp_(0, 1)
        ref_frame = ref_frame.permute(1, 2, 0)
        ref_frame = (ref_frame.numpy() * 255).astype('uint8')

        # Caption panel: wrap the caption at `nc` characters per line.
        txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
        draw = ImageDraw.Draw(txt_img)
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
        draw.text((0, 0), lines, fill="black", font=font)
        txt_img = np.array(txt_img)

        # [reference | frame | caption] side by side, one GIF per sample.
        video_data = [np.concatenate([ref_frame, u, txt_img], axis=1) for u in video_data]
        spath = os.path.join(save_dir, '%04d.gif' % (idx))
        imageio.mimwrite(spath, video_data, fps=8)
        # if idx > 100: break
def test_vit_image(test_video_flag=True):
    """Smoke-test the configured image dataset (ViT branch).

    Builds the dataset from ``cfg.img_dataset`` (after merging the CLI config
    loaded by ``pConfig(load=True)`` into the global ``cfg``), then for every
    sample writes a GIF to ``workspace/test_data/datasets/<type>/<exp_name>``
    where each frame is ``[vit frame (zero-padded) | image | caption image]``
    concatenated horizontally.

    Args:
        test_video_flag: unused; kept for interface compatibility.
    """
    # Merge the command-line/config-file overrides into the global cfg.
    cfg_update = pConfig(load=True)
    for k, v in cfg_update.cfg_dict.items():
        if isinstance(v, dict) and k in cfg:
            cfg[k].update(v)
        else:
            cfg[k] = v

    exp_name = os.path.basename(cfg.cfg_file).split('.')[0]
    save_dir = os.path.join('workspace', 'test_data/datasets', cfg.img_dataset['type'], exp_name)
    # shutil.rmtree instead of `os.system('rm -rf %s')`: no shell invocation,
    # no breakage on paths with spaces, and portable.
    shutil.rmtree(save_dir, ignore_errors=True)
    os.makedirs(save_dir, exist_ok=True)

    train_trans = data.Compose([
        data.CenterCropWide(size=cfg.resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.mean, std=cfg.std)])
    vit_trans = data.Compose([
        data.CenterCropWide(cfg.resolution),
        data.Resize(cfg.vit_resolution),
        data.ToTensor(),
        data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])

    # Broadcast shapes for de-normalization over (c, h, w) images.
    img_mean = torch.tensor(cfg.mean).view(-1, 1, 1)
    img_std = torch.tensor(cfg.std).view(-1, 1, 1)
    vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1)
    vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1)

    # Caption rendering: square text panel matching the frame height; `nc`
    # characters per line, scaled from 38 chars at a 256-px panel.
    txt_size = cfg.resolution[1]
    nc = int(38 * (txt_size / 256))
    # NOTE(review): test_video_dataset loads the same font from 'data/font/...';
    # one of the two paths is presumably stale — confirm which exists.
    font = ImageFont.truetype('artist/font/DejaVuSans.ttf', size=13)

    dataset = DATASETS.build(cfg.img_dataset, transforms=train_trans, vit_transforms=vit_trans)
    print('There are %d videos' % (len(dataset)))
    for idx, item in enumerate(dataset):
        ref_frame, vit_frame, video_data, caption, video_key = item

        # Undo normalization (x_norm * std + mean) and convert to uint8 HWC frames.
        video_data = video_data.mul_(img_std).add_(img_mean)
        video_data.clamp_(0, 1)
        video_data = video_data.permute(0, 2, 3, 1)
        video_data = [(image.numpy() * 255).astype('uint8') for image in video_data]

        # Single ViT image, zero-padded into a square panel of the frame height
        # (vit_resolution is typically smaller than resolution).
        vit_frame = vit_frame.mul_(vit_std).add_(vit_mean)
        vit_frame.clamp_(0, 1)
        vit_frame = vit_frame.permute(1, 2, 0)
        vit_frame = (vit_frame.numpy() * 255).astype('uint8')
        zero_frame = np.zeros((cfg.resolution[1], cfg.resolution[1], 3), dtype=np.uint8)
        zero_frame[:vit_frame.shape[0], :vit_frame.shape[1], :] = vit_frame

        # Caption panel: wrap the caption at `nc` characters per line.
        txt_img = Image.new("RGB", (txt_size, txt_size), color="white")
        draw = ImageDraw.Draw(txt_img)
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))
        draw.text((0, 0), lines, fill="black", font=font)
        txt_img = np.array(txt_img)

        # [vit panel | image | caption] side by side, one GIF per sample.
        video_data = [np.concatenate([zero_frame, u, txt_img], axis=1) for u in video_data]
        spath = os.path.join(save_dir, '%04d.gif' % (idx))
        imageio.mimwrite(spath, video_data, fps=8)
        # if idx > 100: break
if __name__ == '__main__':
    # Entry point: run one dataset smoke test; the config to use is picked up
    # from the command line inside each function via pConfig(load=True).
    # test_video_dataset()
    test_vit_image()