#!/usr/bin/env python3
"""
Script to decode tokenized video into images/video.

Example usage: See https://github.com/1x-technologies/1xgpt?tab=readme-ov-file#1x-genie-baseline
"""
import argparse
import math
import os

from PIL import Image, ImageDraw
import numpy as np
import torch
import torch.distributed.optim
import torch.utils.checkpoint
import torch.utils.data
import torchvision.transforms.v2.functional as transforms_f
from diffusers import AutoencoderKLTemporalDecoder
from einops import rearrange
from matplotlib import pyplot as plt

from cont_data import RawFeatureDataset
from data import RawTokenDataset
from datasets.utils import get_image_encoder
from magvit2.config import VQConfig
from magvit2.models.lfqgan import VQModel
from common.eval_utils import decode_tokens, decode_features
import wandb
# Authenticate with W&B via the WANDB_API_KEY environment variable; do not hardcode API keys.
wandb.login()

SVD_SCALE = 0.18215  # latent scaling factor of the Stable Diffusion / SVD VAE

def parse_args():
    parser = argparse.ArgumentParser(description="Visualize tokenized video as GIF or comic.")
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Frame skip",
    )
    parser.add_argument(
        "--token_dir",
        type=str,
        default="data/genie_generated",
        help="Directory of tokens, in the format of `video.bin` and `metadata.json`. "
             "Visualized gif and comic will be written here.",
    )
    parser.add_argument(
        "--offset", type=int, default=0, help="Offset to start generating images from"
    )
    parser.add_argument(
        "--fps", type=int, default=2, help="Frames per second"
    )
    parser.add_argument(
        "--max_images", type=int, default=None, help="Maximum number of images to generate. None for all."
    )
    parser.add_argument(
        "--example_ind", type=int, default=0,
        help="The index in the dataset of the example to generate on."
    )
    parser.add_argument(
        "--project_prefix", type=str, default="", help="Prefix for the wandb run name."
    )
    parser.add_argument(
        "--disable_comic", action="store_true",
        help="Comic generation assumes `token_dir` follows the same format as the output of `generate`: e.g., "
             "`prompt | predictions | gtruth` in `video.bin`, and `window_size` in `metadata.json`. "
             "Therefore, comic generation should be disabled when visualizing videos without this format, such as the dataset."
    )
    parser.add_argument(
        "--batch_size", type=int, default=4,
        help="Batch size; the current script only supports a single GPU."
    )
    parser.add_argument(
        "--max_example", type=int, default=4,
        help="Maximum number of examples to visualize."
    )
    parser.add_argument(
        "--use_feature", action="store_true",
        help="Visualize continuous features rather than discrete tokens."
    )

    args = parser.parse_args()
    return args
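# Example invocations (sketch: the script filename `visualize.py` is an assumption, and the
# token directory is simply the default defined above):
#   python visualize.py --token_dir data/genie_generated --fps 2 --max_example 4
#   python visualize.py --token_dir data/genie_generated --use_feature --disable_comic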

def export_to_gif(frames: list, output_gif_path: str, fps: int):
    """
    Export a list of frames to a GIF.

    Args:
    - frames (list): List of frames (as numpy arrays or PIL Image objects).
    - output_gif_path (str): Path to save the output GIF.
    - fps (int): Desired frames per second.
    """
    # Convert numpy arrays to PIL Images if needed
    pil_frames = [Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame
                  for frame in frames]

    duration_ms = 1000 / fps
    pil_frames[0].save(output_gif_path.replace(".mp4", ".gif"),
                       format="GIF",
                       append_images=pil_frames[1:],
                       save_all=True,
                       duration=duration_ms,
                       loop=0)

    # Return the path of the saved GIF
    return output_gif_path.replace(".mp4", ".gif")
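# Usage sketch (illustrative only; `frames` would typically be the PIL images produced by
# `decode_latents` below, and a ".mp4" suffix is rewritten to ".gif" before saving):
#   gif_path = export_to_gif(frames, "example0.gif", fps=2)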

def unnormalize_imgs(normalized_imgs):
    """
    [-1, 1] -> [0, 255]

    Important: clip to [0, 255]
    """
    normalized_imgs = torch.clamp(normalized_imgs, -1, 1)
    rescaled_output = (normalized_imgs.detach().cpu() + 1) * 127.5
    clipped_output = torch.clamp(rescaled_output, 0, 255).to(dtype=torch.uint8)
    return clipped_output
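# Worked example of the [-1, 1] -> [0, 255] mapping above:
#   unnormalize_imgs(torch.tensor([-1.0, 0.0, 1.0]))  # -> tensor([  0, 127, 255], dtype=torch.uint8)
#   (0.0 maps to 127 because 127.5 is truncated when casting to uint8.)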

def decode_latents_wrapper(
    batch_size: int = 16,
    encoder_type: str = "magvit",
    encoder_name_or_path: str = "data/magvit2.ckpt",
    max_images: int = None,
    device: str = "cuda",
):
    dtype = torch.bfloat16
    model = get_image_encoder(encoder_type, encoder_name_or_path)
    model = model.to(device=device, dtype=dtype)

    def decode_latents(video_data: np.ndarray):
        """
        video_data: (b, h, w) for quantized data, or (b, c, h, w) for continuous data,
        where b is `batch_size` and different from training/eval batch size.
        """
        decoded_imgs = []

        for shard_ind in range(math.ceil(len(video_data) / batch_size)):
            shard_data = video_data[shard_ind * batch_size: (shard_ind + 1) * batch_size]

            if isinstance(model, VQModel):  # TODO: class-agnostic wrapper
                # Expecting quantized token indices of shape (b, h, w)
                assert shard_data.ndim == 3, f"{shard_data.shape=} {shard_data.dtype=}"
                torch_shard = torch.from_numpy(shard_data.astype(np.int64))
                # Note: EMA weights are not used here (`model.ema_scope()` is a no-op in the bugged VQModel)
                quant = model.quantize.get_codebook_entry(
                    rearrange(torch_shard, "b h w -> b (h w)"),
                    bhwc=torch_shard.shape + (model.quantize.codebook_dim,),
                ).flip(1)
                normalized_imgs = model.decode(quant.to(device=device, dtype=dtype))
            elif isinstance(model, AutoencoderKLTemporalDecoder):
                # Expecting continuous latents of shape (b, c, h, w)
                assert shard_data.ndim == 4, f"{shard_data.shape=} {shard_data.dtype=}"
                torch_shard = torch.from_numpy(shard_data)
                # Clip outlier latent values to a bounded range before decoding
                torch_shard = torch.clamp(torch_shard, -25, 25)
                # `.sample` holds the decoded image batch
                normalized_imgs = model.decode(torch_shard.to(device=device, dtype=dtype), num_frames=1).sample
            else:
                raise NotImplementedError(f"{model=}")

            decoded_imgs.append(unnormalize_imgs(normalized_imgs))
            if max_images and len(decoded_imgs) * batch_size >= max_images:
                break

        return [transforms_f.to_pil_image(img) for img in torch.cat(decoded_imgs)]

    return decode_latents
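# Usage sketch (illustrative; the checkpoint path is this function's default, and the token
# array shape follows the `decode_latents` docstring):
#   decode_latents = decode_latents_wrapper(batch_size=4, encoder_type="magvit",
#                                           encoder_name_or_path="data/magvit2.ckpt")
#   pil_frames = decode_latents(token_array)  # token_array: (num_frames, h, w) integer tokens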

def caption_image(pil_image: Image.Image, caption: str):
    """
    Add a bit of empty space at the top, and add the caption there.
    """
    border_size = 36
    font_size = 24

    # Convert to a PIL.Image.Image if it is not one already
    if not isinstance(pil_image, Image.Image):
        pil_image = transforms_f.to_pil_image(pil_image)

    width, height = pil_image.size
    new_width = width
    new_height = height + border_size
    new_image = Image.new("RGB", (new_width, new_height), "white")
    new_image.paste(pil_image, (0, border_size))

    # Draw the caption, centered horizontally (the `align` keyword doesn't center single-line text)
    draw = ImageDraw.Draw(new_image)
    _, _, text_w, text_h = draw.textbbox((0, 0), caption, font_size=font_size)
    draw.text(((width - text_w) / 2, (border_size - text_h) / 2), caption, fill="black", font_size=font_size)

    return new_image
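# Usage sketch (illustrative): returns a new PIL image with a 36 px white banner holding the
# centered caption above the original frame.
#   captioned = caption_image(video_frames[0], "Prompt")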

def main():
    args = parse_args()

    # The parent directory of `token_dir` is assumed to encode the model and dataset names
    # around a "nodes" marker, e.g. `<model>_..._nodes_<dataset>`.
    name = args.token_dir.split('/')[-2]
    name_split = name.find('nodes')
    model = name[:name_split - 7]
    dataset = name[name_split + 8:]

    # Load tokens
    if args.use_feature:
        token_dataset = RawFeatureDataset(args.token_dir, 1, compute_stride_from_freq_table=False,
                                          filter_interrupts=False, filter_overlaps=False)
    else:
        token_dataset = RawTokenDataset(args.token_dir, 1, compute_stride_from_freq_table=False,
                                        filter_interrupts=False, filter_overlaps=False)

    video_tokens = token_dataset.data
    print(f"Loaded {video_tokens.shape=}")

    metadata = token_dataset.metadata
    # Each example holds the prompt frames, the predicted frames, and the ground truth frames,
    # i.e. window_size * 2 - num_prompt_frames frames in total.
    video_tokens = video_tokens.reshape(-1, metadata["window_size"] * 2 - metadata["num_prompt_frames"],
                                        *video_tokens.shape[1:])
    decode_func = decode_latents_wrapper
    print(metadata)
    print(f"Reshape {video_tokens.shape=}")

    wandb.init(project='video_eval_vis', settings=wandb.Settings(start_method="thread"),
               name=f"{args.project_prefix}vis_{model}", id=f"{args.project_prefix}vis_{model}",
               resume="allow")
    for example_id in range(min(args.max_example, len(video_tokens))):
        if args.use_feature:
            if "encoder_type" not in metadata:
                metadata["encoder_type"] = "temporalvae"
                metadata["encoder_name_or_path"] = "stabilityai/stable-video-diffusion-img2vid"

            # NOTE: `args.offset::args.stride` slicing is not applied here
            decode_latents = decode_func(max_images=args.max_images,
                                         encoder_name_or_path=metadata["encoder_name_or_path"],
                                         encoder_type=metadata["encoder_type"])
            # Undo the SVD latent scaling before decoding
            this_video_token = torch.FloatTensor(video_tokens[example_id].copy())[None] / SVD_SCALE
            this_video_token = rearrange(this_video_token, "b t c h w -> b t h w c")
            video_frames = decode_features(this_video_token, decode_latents)
        else:
            decode_latents = decode_func(max_images=args.max_images)
            this_video_token = torch.LongTensor(video_tokens[example_id])[None]
            video_frames = decode_tokens(this_video_token, decode_latents)

        video_frames = rearrange(video_frames, "b t c h w -> b t h w c")
        video_frames = video_frames.detach().cpu().numpy()[0].astype(np.uint8)
        output_gif_path = os.path.join(args.token_dir, f"example{args.offset}.gif")

        # `generate` should populate `metadata.json` with these keys, while ground truth metadata does not have them
        is_generated_data = all(key in metadata for key in ("num_prompt_frames", "window_size"))
        if is_generated_data:
            if video_tokens[example_id].shape[0] != metadata["window_size"] * 2 - metadata["num_prompt_frames"]:
                raise ValueError(f"Unexpected {video_tokens.shape=} given "
                                 f"{metadata['window_size']=}, {metadata['num_prompt_frames']=}")

            captioned_frames = []
            for i, frame in enumerate(video_frames):
                if i < metadata["num_prompt_frames"]:
                    caption = "Prompt"
                elif i < metadata["window_size"]:
                    caption = "Generated"
                else:
                    caption = "Ground truth"

                captioned_frames.append(caption_image(frame, caption))
        else:
            # Leave ground truth frames uncaptioned
            captioned_frames = video_frames

        gif_path = export_to_gif(captioned_frames, output_gif_path, args.fps)
        print(f"Saved to {output_gif_path}")
        if not args.disable_comic:
            fig, axs = plt.subplots(nrows=2, ncols=metadata["window_size"],
                                    figsize=(3 * metadata["window_size"], 3 * 2))

            for i, image in enumerate(video_frames):
                if i < metadata["num_prompt_frames"]:
                    curr_axs = [axs[0, i], axs[1, i]]
                    title = "Prompt"
                elif i < metadata["window_size"]:
                    curr_axs = [axs[0, i]]
                    title = "Prediction"
                else:
                    curr_axs = [axs[1, i - metadata["window_size"] + metadata["num_prompt_frames"]]]
                    title = "Ground truth"

                for ax in curr_axs:
                    ax.set_title(title)
                    ax.imshow(image)
                    ax.axis("off")

            output_comic_path = os.path.join(args.token_dir, f"example{args.offset}.png")
            plt.savefig(output_comic_path, bbox_inches="tight")
            plt.close()
            print(f"Saved to {output_comic_path}")
        # Log the GIF for this example to wandb
        wandb.log({f"{dataset}/gif_{example_id}": wandb.Video(gif_path)})
        # wandb.log({f"{dataset}/comic_{args.example_ind}": wandb.Image(output_comic_path)})

    wandb.run.summary["model_checkpoint"] = metadata["model_checkpoint"]
    wandb.run.summary["dataset"] = metadata["dataset"]
    wandb.run.summary["trained_steps"] = metadata["trained_steps"]
    wandb.finish()

if __name__ == "__main__":
    main()