Spaces:
Running
on
L40S
Running
on
L40S
import os | |
import torch | |
import torchvision | |
from PIL import Image | |
import numpy as np | |
import imageio | |
from einops import rearrange, repeat | |
def load_image(img, size): | |
# img = Image.open(filename).convert('RGB') | |
if not isinstance(img, np.ndarray): | |
img = Image.open(img).convert('RGB') | |
img = img.resize((size, size)) | |
img = np.asarray(img) | |
img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256 | |
return img / 255.0 | |
def img_preprocessing(img_path, size): | |
img = load_image(img_path, size) # [0, 1] | |
img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1] | |
imgs_norm = (img - 0.5) * 2.0 # [-1, 1] | |
return imgs_norm | |
def resize(img, size): | |
transform = torchvision.transforms.Compose([ | |
torchvision.transforms.Resize(size, antialias=True), | |
torchvision.transforms.CenterCrop(size) | |
]) | |
return transform(img) | |
def vid_preprocessing(vid_path, size): | |
vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec') | |
vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0) # btchw | |
fps = vid_dict[2]['video_fps'] | |
vid_norm = (vid / 255.0 - 0.5) * 2.0 # [-1, 1] | |
vid_norm = torch.cat([ | |
resize(vid_norm[:, i, :, :, :], size).unsqueeze(1) for i in range(vid.size(1)) | |
], dim=1) | |
return vid_norm, fps | |
def img_denorm(img): | |
img = img.clamp(-1, 1).cpu() | |
img = (img - img.min()) / (img.max() - img.min()) | |
return img | |
def vid_denorm(vid): | |
vid = vid.clamp(-1, 1).cpu() | |
vid = (vid - vid.min()) / (vid.max() - vid.min()) | |
return vid | |
def save_img_edit(save_dir, img, img_e): | |
# img: BCHW | |
# img_e: BCHW | |
output_img_path = os.path.join(save_dir, "img_edit.png") | |
output_img_all_path = os.path.join(save_dir, "img_all.png") | |
img = rearrange(img, 'b c h w -> b h w c') | |
img_e = rearrange(img_e, 'b c h w -> b h w c') | |
img_all = torch.cat([img, img_e], dim=2) | |
img_e_np = (img_denorm(img_e[0]).numpy() * 255).astype('uint8') | |
img_all_np = (img_denorm(img_all[0]).numpy() * 255).astype('uint8') | |
imageio.imwrite(output_img_path, img_e_np, quality=8) | |
imageio.imwrite(output_img_all_path, img_all_np, quality=8) | |
return | |
def save_vid_edit(save_dir, vid_d, vid_a, fps): | |
# img_s: BCHW | |
# vid_d: BTCHW | |
# vid_a: BCTHW | |
output_vid_a_path = os.path.join(save_dir, "vid_animation.mp4") | |
output_vid_all_path = os.path.join(save_dir, "vid_all.mp4") | |
vid_d = rearrange(vid_d, 'b t c h w -> b t h w c') | |
vid_a = rearrange(vid_a, 'b c t h w -> b t h w c') | |
vid_all = torch.cat([vid_d, vid_a], dim=3) | |
vid_a_np = (vid_denorm(vid_a[0]).numpy() * 255).astype('uint8') | |
vid_all_np = (vid_denorm(vid_all[0]).numpy() * 255).astype('uint8') | |
imageio.mimwrite(output_vid_a_path, vid_a_np, fps=fps, codec='libx264', quality=8) | |
imageio.mimwrite(output_vid_all_path, vid_all_np, fps=fps, codec='libx264', quality=8) | |
return | |
def save_animation(save_dir, img_s, vid_d, vid_a, fps): | |
# img_s: BCHW | |
# vid_d: BTCHW | |
# vid_a: BCTHW | |
output_vid_a_path = os.path.join(save_dir, "vid_animation.mp4") | |
output_img_e_path = os.path.join(save_dir, "img_edit.png") | |
output_vid_all_path = os.path.join(save_dir, "vid_all.mp4") | |
vid_d = rearrange(vid_d, 'b t c h w -> b t h w c') | |
vid_a = rearrange(vid_a, 'b c t h w -> b t h w c') | |
img_s = repeat(rearrange(img_s, 'b c h w -> b h w c'), 'b h w c -> b t h w c', t=vid_d.size(1)) | |
vid_all = torch.cat([img_s, vid_d, vid_a], dim=3) | |
vid_a_np = (vid_denorm(vid_a[0]).numpy() * 255).astype('uint8') | |
img_e_np = vid_a_np[0] | |
vid_all_np = (vid_denorm(vid_all[0]).numpy() * 255).astype('uint8') | |
imageio.mimwrite(output_vid_a_path, vid_a_np, fps=fps, codec='libx264', quality=8) | |
imageio.mimwrite(output_vid_all_path, vid_all_np, fps=fps, codec='libx264', quality=8) | |
imageio.imwrite(output_img_e_path, img_e_np, quality=8) | |
return | |
def save_linear_manipulation(save_dir, vid, fps): | |
# vid: BCTHW | |
output_vid_path = os.path.join(save_dir, "vid_interpolation.mp4") | |
vid = rearrange(vid, 'b c t h w -> b t h w c') | |
vid_np = (vid_denorm(vid[0]).numpy() * 255).astype('uint8') | |
imageio.mimwrite(output_vid_path, vid_np, fps=fps, codec='libx264', quality=8) | |
return |