# LIA-X-fast/utils/data_processing.py

import os
import torch
import torchvision
from PIL import Image
import numpy as np
import imageio
from einops import rearrange, repeat


def load_image(img, size):
    # Accept either a path / file-like object or an RGB image array
    # (assumed uint8, H x W x 3).
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    else:
        img = Image.open(img).convert('RGB')
    img = img.resize((size, size))
    img = np.asarray(img)
    img = np.transpose(img, (2, 0, 1))  # 3 x size x size
    return img / 255.0


def img_preprocessing(img_path, size):
    img = load_image(img_path, size)  # [0, 1]
    img = torch.from_numpy(img).unsqueeze(0).float()  # [0, 1]
    imgs_norm = (img - 0.5) * 2.0  # [-1, 1]
    return imgs_norm
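
# Example (sketch, not in the original module): 'source.png' is a placeholder path;
# 256 mirrors the default resolution suggested by the original comments in this file.
#   img_s = img_preprocessing('source.png', 256)  # 1 x 3 x 256 x 256 tensor in [-1, 1]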


def resize(img, size):
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size, antialias=True),
        torchvision.transforms.CenterCrop(size)
    ])
    return transform(img)


def vid_preprocessing(vid_path, size):
    vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
    vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0)  # B x T x C x H x W
    fps = vid_dict[2]['video_fps']
    vid_norm = (vid / 255.0 - 0.5) * 2.0  # [-1, 1]
    vid_norm = torch.cat([
        resize(vid_norm[:, i, :, :, :], size).unsqueeze(1) for i in range(vid.size(1))
    ], dim=1)
    return vid_norm, fps
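
# Example (sketch, not in the original module): 'driving.mp4' is a placeholder path.
#   vid_d, fps = vid_preprocessing('driving.mp4', 256)  # 1 x T x 3 x 256 x 256 in [-1, 1], plus source fps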


def img_denorm(img):
    # Map from [-1, 1] back to [0, 1]: clamp, then rescale by the tensor's own min/max.
    img = img.clamp(-1, 1).cpu()
    img = (img - img.min()) / (img.max() - img.min())
    return img


def vid_denorm(vid):
    # Same per-tensor min/max rescaling as img_denorm, applied to a video tensor.
    vid = vid.clamp(-1, 1).cpu()
    vid = (vid - vid.min()) / (vid.max() - vid.min())
    return vid


def save_img_edit(save_dir, img, img_e):
    # img: BCHW
    # img_e: BCHW
    output_img_path = os.path.join(save_dir, "img_edit.png")
    output_img_all_path = os.path.join(save_dir, "img_all.png")
    img = rearrange(img, 'b c h w -> b h w c')
    img_e = rearrange(img_e, 'b c h w -> b h w c')
    img_all = torch.cat([img, img_e], dim=2)
    img_e_np = (img_denorm(img_e[0]).numpy() * 255).astype('uint8')
    img_all_np = (img_denorm(img_all[0]).numpy() * 255).astype('uint8')
    imageio.imwrite(output_img_path, img_e_np, quality=8)
    imageio.imwrite(output_img_all_path, img_all_np, quality=8)
    return


def save_vid_edit(save_dir, vid_d, vid_a, fps):
    # vid_d: BTCHW
    # vid_a: BCTHW
    output_vid_a_path = os.path.join(save_dir, "vid_animation.mp4")
    output_vid_all_path = os.path.join(save_dir, "vid_all.mp4")
    vid_d = rearrange(vid_d, 'b t c h w -> b t h w c')
    vid_a = rearrange(vid_a, 'b c t h w -> b t h w c')
    vid_all = torch.cat([vid_d, vid_a], dim=3)  # driving | animation, side by side along width
    vid_a_np = (vid_denorm(vid_a[0]).numpy() * 255).astype('uint8')
    vid_all_np = (vid_denorm(vid_all[0]).numpy() * 255).astype('uint8')
    imageio.mimwrite(output_vid_a_path, vid_a_np, fps=fps, codec='libx264', quality=8)
    imageio.mimwrite(output_vid_all_path, vid_all_np, fps=fps, codec='libx264', quality=8)
    return


def save_animation(save_dir, img_s, vid_d, vid_a, fps):
    # img_s: BCHW
    # vid_d: BTCHW
    # vid_a: BCTHW
    output_vid_a_path = os.path.join(save_dir, "vid_animation.mp4")
    output_img_e_path = os.path.join(save_dir, "img_edit.png")
    output_vid_all_path = os.path.join(save_dir, "vid_all.mp4")
    vid_d = rearrange(vid_d, 'b t c h w -> b t h w c')
    vid_a = rearrange(vid_a, 'b c t h w -> b t h w c')
    img_s = repeat(rearrange(img_s, 'b c h w -> b h w c'), 'b h w c -> b t h w c', t=vid_d.size(1))
    vid_all = torch.cat([img_s, vid_d, vid_a], dim=3)  # source | driving | animation, side by side along width
    vid_a_np = (vid_denorm(vid_a[0]).numpy() * 255).astype('uint8')
    img_e_np = vid_a_np[0]  # first animated frame, saved as the edited image
    vid_all_np = (vid_denorm(vid_all[0]).numpy() * 255).astype('uint8')
    imageio.mimwrite(output_vid_a_path, vid_a_np, fps=fps, codec='libx264', quality=8)
    imageio.mimwrite(output_vid_all_path, vid_all_np, fps=fps, codec='libx264', quality=8)
    imageio.imwrite(output_img_e_path, img_e_np, quality=8)
    return


def save_linear_manipulation(save_dir, vid, fps):
    # vid: BCTHW
    output_vid_path = os.path.join(save_dir, "vid_interpolation.mp4")
    vid = rearrange(vid, 'b c t h w -> b t h w c')
    vid_np = (vid_denorm(vid[0]).numpy() * 255).astype('uint8')
    imageio.mimwrite(output_vid_path, vid_np, fps=fps, codec='libx264', quality=8)
    return
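

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): random tensors stand
    # in for a real source image, driving video, and generated animation, purely to
    # exercise the saving utilities above. Assumes imageio's ffmpeg backend is
    # available for the .mp4 outputs; files are written to the current directory.
    img_s = torch.rand(1, 3, 256, 256) * 2 - 1      # B x C x H x W in [-1, 1]
    vid_d = torch.rand(1, 8, 3, 256, 256) * 2 - 1   # B x T x C x H x W in [-1, 1]
    vid_a = torch.rand(1, 3, 8, 256, 256) * 2 - 1   # B x C x T x H x W in [-1, 1]
    save_animation('.', img_s, vid_d, vid_a, fps=25)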