import os
import imageio
import numpy as np
from typing import Literal, Union, List, Dict, Tuple
import torch
import torchvision
import cv2
from PIL import Image
from tqdm import tqdm
from einops import rearrange
import webp
import subprocess
from .. import logger
def save_videos_to_images(videos: np.ndarray, path: str, image_type: str = "png") -> None:
    """save a batch of video frames as individual images of type image_type
    Args:
        videos (np.ndarray): frames of shape [t h w c]
        path (str): image directory path
        image_type (str, optional): image file extension. Defaults to "png".
    """
os.makedirs(path, exist_ok=True)
for i, video in enumerate(videos):
imageio.imsave(os.path.join(path, f"{i:04d}.{image_type}"), video)
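# A minimal usage sketch for save_videos_to_images; the frame shape and output
# directory below are illustrative assumptions, not part of the original module.
def _example_save_videos_to_images() -> None:
    frames = np.random.randint(0, 256, (8, 64, 64, 3), dtype=np.uint8)  # t h w c
    save_videos_to_images(frames, "outputs/example_frames", image_type="png")
    # -> outputs/example_frames/0000.png ... outputs/example_frames/0007.png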
def save_videos_grid(
videos: torch.Tensor,
path: str,
rescale=False,
    n_rows=4,  # how many videos per row
fps=8,
save_type="webp",
) -> None:
videos = rearrange(videos, "b c t h w -> t b c h w")
outputs = []
for x in videos:
x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # c h w -> h w c
if rescale:
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
if x.dtype != torch.uint8:
x = (x * 255).numpy().astype(np.uint8)
if save_type == "webp":
outputs.append(Image.fromarray(x))
else:
outputs.append(x)
os.makedirs(os.path.dirname(path), exist_ok=True)
if "gif" in path or save_type == "gif":
params = {
"duration": int(1000 * 1.0 / fps),
"loop": 0,
}
elif save_type == "mp4":
params = {
"quality": 9,
"fps": fps,
"pixelformat": "yuv420p",
}
else:
params = {
"quality": 9,
"fps": fps,
}
if save_type == "webp":
webp.save_images(outputs, path, fps=fps, lossless=True)
else:
imageio.mimsave(path, outputs, **params)
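# A minimal usage sketch for save_videos_grid; the tensor shape, value range, and
# output path are illustrative assumptions based on the rearrange pattern and the
# 0-1 scaling above.
def _example_save_videos_grid() -> None:
    videos = torch.rand(2, 3, 8, 64, 64)  # b c t h w, values in [0, 1]
    save_videos_grid(videos, "outputs/example_grid.webp", n_rows=2, fps=8, save_type="webp")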
def make_grid_with_opencv(
batch: Union[torch.Tensor, np.ndarray],
nrows: int,
texts: List[str] = None,
rescale: bool = False,
font_size: float = 0.05,
font_thickness: int = 1,
font_color: Tuple[int] = (255, 0, 0),
tensor_order: str = "b c h w",
write_info: bool = False,
) -> np.ndarray:
"""read tensor batch and make a grid with opencv
Args:
batch (Union[torch.Tensor, np.ndarray]): 4 dim tensor, like b c h w
nrows (int): how many rows in the grid
        texts (List[str], optional): text to write on each image. Defaults to None.
        rescale (bool, optional): whether to rescale from [-1, 1] to [0, 1]. Defaults to False.
        font_size (float, optional): font size. Defaults to 0.05.
        font_thickness (int, optional): font thickness. Defaults to 1.
        font_color (Tuple[int], optional): text color. Defaults to (255, 0, 0).
        tensor_order (str, optional): dimension order of the batch. Defaults to "b c h w".
        write_info (bool, optional): whether to write text onto the image. Defaults to False.
Returns:
np.ndarray: h w c
"""
if isinstance(batch, torch.Tensor):
batch = batch.cpu().numpy()
# batch: (B, C, H, W)
batch = rearrange(batch, f"{tensor_order} -> b h w c")
b, h, w, c = batch.shape
ncols = int(np.ceil(b / nrows))
grid = np.zeros((h * nrows, w * ncols, c), dtype=np.uint8)
font = cv2.FONT_HERSHEY_SIMPLEX
for i, x in enumerate(batch):
i_row, i_col = i // ncols, i % ncols
if rescale:
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
x = (x * 255).astype(np.uint8)
        # without this copy, cv2.putText raises an error
# ref: https://stackoverflow.com/questions/72327137/opencv4-5-5-error-5bad-argument-in-function-puttext
x = x.copy()
if texts is not None and write_info:
x = cv2.putText(
x,
texts[i],
(5, 20),
font,
fontScale=font_size,
color=font_color,
thickness=font_thickness,
)
grid[i_row * h : (i_row + 1) * h, i_col * w : (i_col + 1) * w, :] = x
return grid
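# A minimal usage sketch for make_grid_with_opencv; the batch shape and texts are
# illustrative assumptions.
def _example_make_grid_with_opencv() -> np.ndarray:
    batch = torch.rand(4, 3, 64, 64)  # b c h w, values in [0, 1]
    grid = make_grid_with_opencv(batch, nrows=2, texts=["a", "b", "c", "d"], write_info=True)
    return grid  # uint8 grid image of shape (2 * 64, 2 * 64, 3)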
def save_videos_grid_with_opencv(
videos: Union[torch.Tensor, np.ndarray],
path: str,
n_cols: int,
texts: List[str] = None,
rescale: bool = False,
fps: int = 8,
    font_size: float = 0.6,
font_thickness: int = 1,
font_color: Tuple[int] = (255, 0, 0),
tensor_order: str = "b c t h w",
batch_dim: int = 0,
split_size_or_sections: int = None, # split batch to avoid large video
write_info: bool = False,
save_filetype: Literal["gif", "mp4", "webp"] = "mp4",
save_images: bool = False,
) -> None:
"""存储tensor视频为gif、mp4等
Args:
videos (Union[torch.Tensor, np.ndarray]): 五维视频tensor, 如 b c t h w,值范围[0-1]
path (str): 视频存储路径,后缀会影响存储方式
n_cols (int): 由于b可能特别大,所以会分成几列
texts (List[str], optional): b长度,会写在每个b视频左上角. Defaults to None.
rescale (bool, optional): 输入是[-1,1]时,应该为True. Defaults to False.
fps (int, optional): 存储视频的fps. Defaults to 8.
font_size (int, optional): text对应的字体大小. Defaults to 0.6.
font_thickness (int, optional): 字体宽度. Defaults to 1.
font_color (Tuple[int], optional): 字体颜色. Defaults to (255, 0, 0).
tensor_order (str, optional): 输入tensor的顺序,如果不是 `b c t h w`,会被转换成 b c t h w,. Defaults to "b c t h w".
batch_dim (int, optional): 有时候b特别大,这时候一个视频就太大了,就可以分成几个视频存储. Defaults to 0.
split_size_or_sections (int, optional): 不为None时,与batch_dim配套,一个存储视频最多支持几个子视频。会按照n_cols截断向上取整数. Defaults to None.
write_info (bool, False): 是否也些提示信息在视频上
"""
if split_size_or_sections is not None:
split_size_or_sections = int(np.ceil(split_size_or_sections / n_cols)) * n_cols
if isinstance(videos, np.ndarray):
videos = torch.from_numpy(videos)
            # torch.split is a better fit here than np.array_split
videos_split = torch.split(videos, split_size_or_sections, dim=batch_dim)
videos_split = [videos.cpu().numpy() for videos in videos_split]
else:
videos_split = [videos]
n_videos_split = len(videos_split)
dirname, basename = os.path.dirname(path), os.path.basename(path)
filename, ext = os.path.splitext(basename)
os.makedirs(dirname, exist_ok=True)
for i_video, videos in enumerate(videos_split):
videos = rearrange(videos, f"{tensor_order} -> t b c h w")
outputs = []
font = cv2.FONT_HERSHEY_SIMPLEX
batch_size = videos.shape[1]
n_rows = int(np.ceil(batch_size / n_cols))
for t, x in enumerate(videos):
x = make_grid_with_opencv(
x,
n_rows,
texts,
rescale,
font_size,
font_thickness,
font_color,
write_info=write_info,
)
h, w, c = x.shape
x = x.copy()
if write_info:
x = cv2.putText(
x,
str(t),
(5, h - 20),
font,
fontScale=2,
color=font_color,
thickness=font_thickness,
)
outputs.append(x)
logger.debug(f"outputs[0].shape: {outputs[0].shape}")
        # TODO: the implementation below should be revisited
if i_video == 0 and n_videos_split == 1:
pass
else:
path = os.path.join(dirname, "{}_{}{}".format(filename, i_video, ext))
if save_filetype == "gif":
params = {
"duration": int(1000 * 1.0 / fps),
"loop": 0,
}
imageio.mimsave(path, outputs, **params)
elif save_filetype == "mp4":
params = {
"quality": 9,
"fps": fps,
}
imageio.mimsave(path, outputs, **params)
elif save_filetype == "webp":
outputs = [Image.fromarray(x_tmp) for x_tmp in outputs]
webp.save_images(outputs, path, fps=fps, lossless=True)
else:
raise ValueError(f"Unsupported file type: {save_filetype}")
if save_images:
images_path = os.path.join(dirname, filename)
os.makedirs(images_path, exist_ok=True)
save_videos_to_images(outputs, images_path)
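# A minimal usage sketch for save_videos_grid_with_opencv; the batch size, shapes,
# and output path are illustrative assumptions (saving mp4 relies on the imageio
# ffmpeg backend, as in the function above).
def _example_save_videos_grid_with_opencv() -> None:
    videos = torch.rand(4, 3, 8, 64, 64)  # b c t h w, values in [0, 1]
    save_videos_grid_with_opencv(
        videos,
        "outputs/example_grid_cv.mp4",
        n_cols=2,
        texts=["v0", "v1", "v2", "v3"],
        fps=8,
        save_filetype="mp4",
    )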
def export_to_video(videos: torch.Tensor, output_video_path: str, fps=8):
tmp_path = output_video_path.replace(".mp4", "_tmp.mp4")
videos = rearrange(videos, "b c t h w -> b t h w c")
videos = videos.squeeze()
videos = (videos * 255).cpu().detach().numpy().astype(np.uint8) # tensor -> numpy
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
h, w, _ = videos[0].shape
video_writer = cv2.VideoWriter(
tmp_path, fourcc, fps=fps, frameSize=(w, h), isColor=True
)
for i in range(len(videos)):
img = cv2.cvtColor(videos[i], cv2.COLOR_RGB2BGR)
video_writer.write(img)
    video_writer.release()  # release the video writer, otherwise the file cannot be played
cmd = f"ffmpeg -y -i {tmp_path} -c:v libx264 -c:a aac -strict -2 {output_video_path} -loglevel quiet"
subprocess.run(cmd, shell=True)
os.remove(tmp_path)
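# A minimal usage sketch for export_to_video; the tensor and output path are
# illustrative assumptions (ffmpeg must be available on PATH, and the output
# directory has to exist because the function does not create it).
def _example_export_to_video() -> None:
    os.makedirs("outputs", exist_ok=True)
    videos = torch.rand(1, 3, 8, 64, 64)  # b c t h w, values in [0, 1]
    export_to_video(videos, "outputs/example_export.mp4", fps=8)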
# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
uncond_input = pipeline.tokenizer(
[""],
padding="max_length",
max_length=pipeline.tokenizer.model_max_length,
return_tensors="pt",
)
uncond_embeddings = pipeline.text_encoder(
uncond_input.input_ids.to(pipeline.device)
)[0]
text_input = pipeline.tokenizer(
[prompt],
padding="max_length",
max_length=pipeline.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
context = torch.cat([uncond_embeddings, text_embeddings])
return context
def next_step(
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
ddim_scheduler,
):
timestep, next_timestep = (
min(
timestep
- ddim_scheduler.config.num_train_timesteps
// ddim_scheduler.num_inference_steps,
999,
),
timestep,
)
alpha_prod_t = (
ddim_scheduler.alphas_cumprod[timestep]
if timestep >= 0
else ddim_scheduler.final_alpha_cumprod
)
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
beta_prod_t = 1 - alpha_prod_t
next_original_sample = (
sample - beta_prod_t**0.5 * model_output
) / alpha_prod_t**0.5
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
next_sample = (
alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
)
return next_sample
def get_noise_pred_single(latents, t, context, unet):
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
return noise_pred
@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
context = init_prompt(prompt, pipeline)
uncond_embeddings, cond_embeddings = context.chunk(2)
all_latent = [latent]
latent = latent.clone().detach()
for i in tqdm(range(num_inv_steps)):
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
latent = next_step(noise_pred, t, latent, ddim_scheduler)
all_latent.append(latent)
return all_latent
@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
ddim_latents = ddim_loop(
pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt
)
return ddim_latents
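# A usage sketch for ddim_inversion, assuming a diffusers-style pipeline object
# exposing .tokenizer, .text_encoder, .unet and .device, plus a DDIMScheduler; the
# argument names and the set_timesteps call are assumptions for illustration.
def _example_ddim_inversion(pipeline, ddim_scheduler, video_latent) -> torch.Tensor:
    num_inv_steps = 50
    ddim_scheduler.set_timesteps(num_inv_steps)
    ddim_latents = ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="")
    return ddim_latents[-1]  # the fully inverted latent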
def fn_recursive_search(
name: str,
module: torch.nn.Module,
target: str,
print_method=print,
print_name: str = "data",
):
if hasattr(module, target):
print_method(
[
name + "." + target + "." + print_name,
getattr(getattr(module, target), print_name)[0].cpu().detach().numpy(),
]
)
parent_name = name
for name, child in module.named_children():
fn_recursive_search(
parent_name + "." + name, child, target, print_method, print_name
)
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(
dim=list(range(1, noise_pred_text.ndim)), keepdim=True
)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = (
guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
)
return noise_cfg
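# A minimal usage sketch for rescale_noise_cfg within classifier-free guidance;
# the guidance scale and tensor shapes are illustrative assumptions.
def _example_rescale_noise_cfg() -> torch.Tensor:
    noise_pred_uncond = torch.randn(2, 4, 8, 8)
    noise_pred_text = torch.randn(2, 4, 8, 8)
    guidance_scale = 7.5
    noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # pull the guided prediction's std back toward the text branch's std
    return rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)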