import os
import subprocess
from typing import Dict, List, Literal, Optional, Tuple, Union

import imageio
import numpy as np
import torch
import torchvision
import cv2
from PIL import Image
from tqdm import tqdm
from einops import rearrange
import webp

from .. import logger


def save_videos_to_images(videos: np.ndarray, path: str, image_type="png") -> None:
    """Save a batch of video frames into a directory, one `image_type` file per frame.

    Args:
        videos (np.ndarray): frame batch of shape [t h w c], uint8
        path (str): image directory path
    """
    os.makedirs(path, exist_ok=True)
    for i, frame in enumerate(videos):
        imageio.imsave(os.path.join(path, f"{i:04d}.{image_type}"), frame)
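
# Minimal usage sketch (assumes a random uint8 clip; "demo_frames" is a
# hypothetical output directory):
#
#     frames = np.random.randint(0, 256, (8, 64, 64, 3), dtype=np.uint8)
#     save_videos_to_images(frames, "demo_frames", image_type="png")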


def save_videos_grid(
    videos: torch.Tensor,
    path: str,
    rescale=False,
    n_rows=4,
    fps=8,
    save_type="webp",
) -> None:
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        # c h w -> h w c; squeeze a trailing singleton channel if present
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        if x.dtype != torch.uint8:
            x = (x * 255).numpy().astype(np.uint8)
        else:
            # already uint8; still convert to numpy for the writers below
            x = x.numpy()

        if save_type == "webp":
            outputs.append(Image.fromarray(x))
        else:
            outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    if "gif" in path or save_type == "gif":
        params = {
            "duration": int(1000 * 1.0 / fps),
            "loop": 0,
        }
    elif save_type == "mp4":
        params = {
            "quality": 9,
            "fps": fps,
            "pixelformat": "yuv420p",
        }
    else:
        params = {
            "quality": 9,
            "fps": fps,
        }

    if save_type == "webp":
        webp.save_images(outputs, path, fps=fps, lossless=True)
    else:
        imageio.mimsave(path, outputs, **params)
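
# Minimal usage sketch (assumes a random video batch in [0, 1]; "out/demo.webp"
# is a hypothetical output path):
#
#     videos = torch.rand(2, 3, 8, 64, 64)  # b c t h w
#     save_videos_grid(videos, "out/demo.webp", n_rows=2, fps=8, save_type="webp")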


def make_grid_with_opencv(
    batch: Union[torch.Tensor, np.ndarray],
    nrows: int,
    texts: Optional[List[str]] = None,
    rescale: bool = False,
    font_size: float = 0.05,
    font_thickness: int = 1,
    font_color: Tuple[int, int, int] = (255, 0, 0),
    tensor_order: str = "b c h w",
    write_info: bool = False,
) -> np.ndarray:
    """Read a tensor batch and assemble it into an image grid with OpenCV.

    Args:
        batch (Union[torch.Tensor, np.ndarray]): 4-dim tensor, like b c h w
        nrows (int): how many rows in the grid
        texts (List[str], optional): text to write on each cell. Defaults to None.
        rescale (bool, optional): whether to rescale [-1, 1] to [0, 1]. Defaults to False.
        font_size (float, optional): font size. Defaults to 0.05.
        font_thickness (int, optional): font thickness. Defaults to 1.
        font_color (Tuple[int, int, int], optional): text color. Defaults to (255, 0, 0).
        tensor_order (str, optional): batch channel order. Defaults to "b c h w".
        write_info (bool, optional): whether to write text onto the grid. Defaults to False.

    Returns:
        np.ndarray: h w c
    """
    if isinstance(batch, torch.Tensor):
        batch = batch.cpu().numpy()

    batch = rearrange(batch, f"{tensor_order} -> b h w c")
    b, h, w, c = batch.shape
    ncols = int(np.ceil(b / nrows))
    grid = np.zeros((h * nrows, w * ncols, c), dtype=np.uint8)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for i, x in enumerate(batch):
        i_row, i_col = i // ncols, i % ncols
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        # assumes float input in [0, 1] at this point
        x = (x * 255).astype(np.uint8)

        x = x.copy()
        if texts is not None and write_info:
            x = cv2.putText(
                x,
                texts[i],
                (5, 20),
                font,
                fontScale=font_size,
                color=font_color,
                thickness=font_thickness,
            )
        grid[i_row * h : (i_row + 1) * h, i_col * w : (i_col + 1) * w, :] = x
    return grid
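
# Minimal usage sketch (assumes four random frames in [0, 1]; the labels and
# output path are hypothetical):
#
#     batch = np.random.rand(4, 3, 64, 64)  # b c h w
#     grid = make_grid_with_opencv(
#         batch, nrows=2, texts=["a", "b", "c", "d"], write_info=True
#     )
#     imageio.imsave("grid.png", grid)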


def save_videos_grid_with_opencv(
    videos: Union[torch.Tensor, np.ndarray],
    path: str,
    n_cols: int,
    texts: Optional[List[str]] = None,
    rescale: bool = False,
    fps: int = 8,
    font_size: float = 0.6,
    font_thickness: int = 1,
    font_color: Tuple[int, int, int] = (255, 0, 0),
    tensor_order: str = "b c t h w",
    batch_dim: int = 0,
    split_size_or_sections: Optional[int] = None,
    write_info: bool = False,
    save_filetype: Literal["gif", "mp4", "webp"] = "mp4",
    save_images: bool = False,
) -> None:
    """Save a video tensor as gif, mp4, etc.

    Args:
        videos (Union[torch.Tensor, np.ndarray]): 5-dim video tensor, e.g. b c t h w, values in [0, 1]
        path (str): output path; the extension affects how the video is stored
        n_cols (int): since b can be very large, the grid is laid out over this many columns
        texts (List[str], optional): length-b list, written in the top-left corner of each sub-video. Defaults to None.
        rescale (bool, optional): should be True when the input is in [-1, 1]. Defaults to False.
        fps (int, optional): fps of the saved video. Defaults to 8.
        font_size (float, optional): font size for `texts`. Defaults to 0.6.
        font_thickness (int, optional): font thickness. Defaults to 1.
        font_color (Tuple[int, int, int], optional): font color. Defaults to (255, 0, 0).
        tensor_order (str, optional): order of the input tensor; if not `b c t h w`, it is rearranged to b c t h w. Defaults to "b c t h w".
        batch_dim (int, optional): when b is very large a single video file gets too big, so the batch can be split into several files along this dim. Defaults to 0.
        split_size_or_sections (int, optional): when not None, used together with batch_dim; the maximum number of sub-videos per saved file, rounded up to a multiple of n_cols. Defaults to None.
        write_info (bool, optional): whether to write debug info (texts, frame index) onto the video. Defaults to False.
        save_filetype (Literal["gif", "mp4", "webp"], optional): container to save with. Defaults to "mp4".
        save_images (bool, optional): whether to additionally dump the grid frames as images. Defaults to False.
    """
    if split_size_or_sections is not None:
        split_size_or_sections = int(np.ceil(split_size_or_sections / n_cols)) * n_cols
        if isinstance(videos, np.ndarray):
            videos = torch.from_numpy(videos)

        videos_split = torch.split(videos, split_size_or_sections, dim=batch_dim)
        videos_split = [videos.cpu().numpy() for videos in videos_split]
    else:
        videos_split = [videos]
    n_videos_split = len(videos_split)
    dirname, basename = os.path.dirname(path), os.path.basename(path)
    filename, ext = os.path.splitext(basename)
    os.makedirs(dirname, exist_ok=True)

    for i_video, videos in enumerate(videos_split):
        videos = rearrange(videos, f"{tensor_order} -> t b c h w")
        outputs = []
        font = cv2.FONT_HERSHEY_SIMPLEX
        batch_size = videos.shape[1]
        n_rows = int(np.ceil(batch_size / n_cols))
        for t, x in enumerate(videos):
            x = make_grid_with_opencv(
                x,
                n_rows,
                texts,
                rescale,
                font_size,
                font_thickness,
                font_color,
                write_info=write_info,
            )
            h, w, c = x.shape
            x = x.copy()
            if write_info:
                # stamp the frame index in the bottom-left corner
                x = cv2.putText(
                    x,
                    str(t),
                    (5, h - 20),
                    font,
                    fontScale=2,
                    color=font_color,
                    thickness=font_thickness,
                )
            outputs.append(x)
        logger.debug(f"outputs[0].shape: {outputs[0].shape}")

        if n_videos_split > 1:
            # suffix each split so the files do not overwrite one another
            path = os.path.join(dirname, "{}_{}{}".format(filename, i_video, ext))
        if save_filetype == "gif":
            params = {
                "duration": int(1000 * 1.0 / fps),
                "loop": 0,
            }
            imageio.mimsave(path, outputs, **params)
        elif save_filetype == "mp4":
            params = {
                "quality": 9,
                "fps": fps,
            }
            imageio.mimsave(path, outputs, **params)
        elif save_filetype == "webp":
            outputs = [Image.fromarray(x_tmp) for x_tmp in outputs]
            webp.save_images(outputs, path, fps=fps, lossless=True)
        else:
            raise ValueError(f"Unsupported file type: {save_filetype}")
        if save_images:
            # derive the directory from the (possibly suffixed) current path so
            # frames of different splits do not overwrite each other
            images_path = os.path.join(
                dirname, os.path.splitext(os.path.basename(path))[0]
            )
            os.makedirs(images_path, exist_ok=True)
            save_videos_to_images(outputs, images_path)
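
# Minimal usage sketch (random batch in [0, 1]; the output path is hypothetical):
#
#     videos = torch.rand(4, 3, 8, 64, 64)  # b c t h w
#     save_videos_grid_with_opencv(
#         videos, "out/grid.mp4", n_cols=2, fps=8, save_filetype="mp4"
#     )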


def export_to_video(videos: torch.Tensor, output_video_path: str, fps=8):
    # write frames with OpenCV first, then re-encode to H.264 with ffmpeg
    tmp_path = output_video_path.replace(".mp4", "_tmp.mp4")

    videos = rearrange(videos, "b c t h w -> b t h w c")
    videos = videos.squeeze()  # assumes batch size 1 -> t h w c
    videos = (videos * 255).cpu().detach().numpy().astype(np.uint8)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    h, w, _ = videos[0].shape
    video_writer = cv2.VideoWriter(
        tmp_path, fourcc, fps=fps, frameSize=(w, h), isColor=True
    )
    for i in range(len(videos)):
        img = cv2.cvtColor(videos[i], cv2.COLOR_RGB2BGR)
        video_writer.write(img)
    video_writer.release()
    # quote the paths so the command survives spaces in file names
    cmd = (
        f'ffmpeg -y -i "{tmp_path}" -c:v libx264 -c:a aac -strict -2 '
        f'"{output_video_path}" -loglevel quiet'
    )
    subprocess.run(cmd, shell=True)
    os.remove(tmp_path)
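
# Minimal usage sketch (assumes ffmpeg is on PATH; the path is hypothetical):
#
#     videos = torch.rand(1, 3, 8, 64, 64)  # b c t h w, values in [0, 1]
#     export_to_video(videos, "out/clip.mp4", fps=8)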


@torch.no_grad()
def init_prompt(prompt, pipeline):
    """Encode the empty string and `prompt`, returning them concatenated as
    [uncond_embeddings, text_embeddings] for classifier-free guidance."""
    uncond_input = pipeline.tokenizer(
        [""],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt",
    )
    uncond_embeddings = pipeline.text_encoder(
        uncond_input.input_ids.to(pipeline.device)
    )[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])

    return context


def next_step(
    model_output: Union[torch.FloatTensor, np.ndarray],
    timestep: int,
    sample: Union[torch.FloatTensor, np.ndarray],
    ddim_scheduler,
):
    """One reversed DDIM step: move `sample` from `timestep` to the next
    (noisier) timestep, as used for DDIM inversion."""
    timestep, next_timestep = (
        min(
            timestep
            - ddim_scheduler.config.num_train_timesteps
            // ddim_scheduler.num_inference_steps,
            999,
        ),
        timestep,
    )
    alpha_prod_t = (
        ddim_scheduler.alphas_cumprod[timestep]
        if timestep >= 0
        else ddim_scheduler.final_alpha_cumprod
    )
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    # predict x_0 from the current sample, then re-noise it to the next timestep
    next_original_sample = (
        sample - beta_prod_t**0.5 * model_output
    ) / alpha_prod_t**0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = (
        alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
    )
    return next_sample


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        # walk the scheduler's timesteps in reverse (clean -> noisy)
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(
        pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt
    )
    return ddim_latents
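
# Minimal usage sketch (assumes a diffusers-style pipeline and a DDIMScheduler
# on which `set_timesteps` has already been called; all names are illustrative):
#
#     ddim_scheduler.set_timesteps(50)
#     latents = ddim_inversion(
#         pipeline, ddim_scheduler, video_latent, num_inv_steps=50, prompt=""
#     )
#     inverted = latents[-1]  # the noisiest latent, usable as a sampling start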


def fn_recursive_search(
    name: str,
    module: torch.nn.Module,
    target: str,
    print_method=print,
    print_name: str = "data",
):
    """Recursively walk `module` and, for every submodule that has an attribute
    `target`, report `<name>.<target>.<print_name>` via `print_method`."""
    if hasattr(module, target):
        print_method(
            [
                name + "." + target + "." + print_name,
                getattr(getattr(module, target), print_name)[0].cpu().detach().numpy(),
            ]
        )

    parent_name = name
    for name, child in module.named_children():
        fn_recursive_search(
            parent_name + "." + name, child, target, print_method, print_name
        )
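
# Minimal usage sketch (inspect the "weight" attribute throughout a toy model;
# the model is illustrative):
#
#     model = torch.nn.Sequential(torch.nn.Linear(4, 4))
#     fn_recursive_search("model", model, target="weight", print_name="data")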


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True
    )
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

    # rescale the CFG result so its per-sample std matches the text prediction
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)

    # mix with the original CFG result to avoid over-flattened images
    noise_cfg = (
        guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    )
    return noise_cfg
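
# Minimal usage sketch inside a classifier-free-guidance step (all tensors and
# scales are illustrative):
#
#     noise_uncond, noise_text = noise_pred.chunk(2)
#     noise_cfg = noise_uncond + guidance_scale * (noise_text - noise_uncond)
#     noise_cfg = rescale_noise_cfg(noise_cfg, noise_text, guidance_rescale=0.7)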