| """ | |
| Example usage: | |
| `python genie/evaluate.py --checkpoint_dir 1x-technologies/GENIE_35M` | |
| """ | |
| import argparse | |
| import time | |
| import os | |
| import sys | |
| from collections import defaultdict | |
| from pathlib import Path | |
| import accelerate | |
| import wandb | |
| import lpips | |
| import torch | |
| import transformers | |
| from accelerate import DataLoaderConfiguration | |
| from einops import rearrange | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from transformers import default_data_collator | |
| import numpy as np | |
| sys.path.append(os.getcwd()) | |
| import re | |
| from data import RawTokenDataset | |
| from visualize import decode_latents_wrapper | |
| from genie.st_mask_git import STMaskGIT | |
| from skimage import metrics as image_metrics | |
| from cont_data import RawFeatureDataset | |
| from raw_image_data import RawImageDataset | |
| from genie.st_mar import STMAR | |
| from datasets import utils | |
| from common.fid_score import calculate_fid | |
| from common.calculate_fvd import calculate_fvd | |
| from common.eval_utils import decode_tokens, decode_features, compute_lpips, AvgMetric, compute_loss | |
| wandb.login(key='4c1540ebf8cb9964703ac212a937c00848a79b67') | |
| # Hardcoded values for the v1.1 dataset | |
| WINDOW_SIZE = 12 | |
| STRIDE = 15 # Data is 30 Hz so with stride 15, video is 2 Hz | |
| SVD_SCALE = 0.18215 | |
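# Sanity check on the constants above (illustrative arithmetic, not from the original
# script): at 30 Hz with STRIDE=15, sampled frames are 0.5 s apart, so a WINDOW_SIZE=12
# window spans (12 - 1) * 0.5 = 5.5 s of video.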
def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate GENIE-style models.")
    parser.add_argument(
        "--val_data_dir", type=str, default="data/1x_humanoid_magvit_traj10_val",
        help="A directory with video data, should have a `metadata.json` and `video.bin`."
    )
    parser.add_argument(
        "--checkpoint_dir", type=str,
        help="Path to a HuggingFace-style checkpoint."
    )
    parser.add_argument(
        "--batch_size", type=int, default=4,
        help="Batch size; the current script only supports a single GPU."
    )
    parser.add_argument(
        "--maskgit_steps", type=int, default=4, help="Number of MaskGIT sampling steps."
    )
    parser.add_argument(
        "--temperature", type=float, default=0,
        help="Sampling temperature. If `temperature` <= 1e-8, will do greedy sampling."
    )
    parser.add_argument(
        "--save_outputs_dir", type=str,
        help="Debug option. If specified, will save model predictions and ground truths to this directory. "
             "Specifically, will save `{pred_frames,pred_logits,gtruth_frames,gtruth_tokens}.pt`"
    )
    parser.add_argument(
        "--max_examples", type=int, default=200,
        help="If specified, will stop evaluation early after `max_examples` examples."
    )
    parser.add_argument(
        "--autoregressive_time", action="store_true",
        help="If True, generates autoregressively along the time dimension."
    )
    parser.add_argument(
        "--add_action_input", action="store_true",
        help="If True, conditions the video output on actions."
    )
    parser.add_argument(
        "--perturbation_type", type=str, default="gaussian",
        help="Type of perturbation to apply to the action input. Options: gaussian"
    )
    parser.add_argument(
        "--perturbation_scale", type=float, default=0.1,
        help="Perturbation applied to each action dimension."
    )
    parser.add_argument(
        "--project_prefix", type=str, default="", help="Prefix for the wandb run id/name."
    )
    parser.add_argument(
        "--use_feature", action="store_true",
        help="Visualize continuous features rather than discrete tokens."
    )
    parser.add_argument(
        "--use_raw_image", action="store_true",
        help="Use raw images as inputs. Note: with `store_true` and `default=True`, "
             "this flag is effectively always on.",
        default=True
    )
    return parser.parse_args()
def get_model_step(checkpoint_dir):
    """Recover the training step count from the saved LR scheduler state, if present."""
    if os.path.exists(f"{checkpoint_dir}/scheduler.bin"):
        sch = torch.load(f"{checkpoint_dir}/scheduler.bin")
        return sch['_step_count']
    return 0


class GenieEvaluator:
    def __init__(self, args, decode_latents, device="cuda"):
        super().__init__()
        if not os.path.exists(args.checkpoint_dir + "/config.json"):
            # Search for the most recently created checkpoint folder.
            dirs = [os.path.join(args.checkpoint_dir, f.name) for f in os.scandir(args.checkpoint_dir) if f.is_dir()]
            dirs.sort(key=os.path.getctime)
            if len(dirs) > 3 and os.path.join(args.checkpoint_dir, "epoch_1") in dirs:
                dirs.remove(os.path.join(args.checkpoint_dir, "epoch_1"))
            if len(dirs) == 0:
                exit(f"No checkpoint found in {args.checkpoint_dir}")
            # Only keep the three most recent checkpoints; delete the rest.
            paths = dirs[:-3]
            for path in paths:
                print(f"evaluation: removing old checkpoint {path}")
                os.system(f"rm -rf {path}")
            args.checkpoint_dir = dirs[-1]

        print("Loading model from:", args.checkpoint_dir)
        self.model = STMAR.from_pretrained(args.checkpoint_dir)
        self.model_step = get_model_step(args.checkpoint_dir)
        self.model = self.model.to(device=device)
        self.model.eval()
        self.decode_latents = decode_latents
        self.device = device
        self.args = args
    def predict_zframe_logits(self, input_ids: torch.Tensor, action_ids: torch.Tensor = None, domains=None,
                              skip_normalization: bool = False) -> tuple[torch.LongTensor, torch.FloatTensor]:
        """
        Conditioned on each prefix: [frame_0], [frame_0, frame_1], ..., [frame_0, frame_1, ..., frame_{T-1}],
        predict the tokens in the following frame: [pred_frame_1, pred_frame_2, ..., pred_frame_T].

        Image logits are denoised in parallel across the spatial dimensions and teacher-forced
        across the time dimension. To compute the logits, we save both the samples and the logits
        during MaskGIT generation. The total number of forward passes is (T - 1) * maskgit_steps.

        Args:
            input_ids: Tensor of size (B, T*H*W) corresponding to flattened, tokenized images.

        Returns: (samples_THW, factored_logits)
            samples_THW:
                size (B, T-1, H, W) corresponding to the token ids of the predicted frames.
                May differ from the argmax of `factored_logits` when not sampling greedily.
            factored_logits:
                size (B, 512, 2, T-1, H, W) corresponding to the predicted logits.
                Note that we factorize the 2**18 vocabulary into two separate vocabularies of size 512 each.
        """
        inputs_THW = rearrange(input_ids, "b (t h w) ... -> b t h w ...", t=WINDOW_SIZE,
                               h=self.args.latent_h, w=self.args.latent_w).to(self.device)
        all_samples = []
        all_logits = []
        samples_HW = inputs_THW.clone()  # placeholder; overwritten on the first iteration
        for timestep in range(1, WINDOW_SIZE):
            print(f"Generating frame {timestep}")
            inputs_masked = inputs_THW.clone()
            if self.args.autoregressive_time:
                # Past the prompt frames, condition on the model's own sample for the previous frame.
                if timestep > self.model.config.num_prompt_frames:
                    inputs_masked[:, timestep - 1] = samples_HW.clone()
            inputs_masked[:, timestep:] = self.model.mask_token

            # MaskGIT sampling
            samples_HW, factored_logits, _ = self.model.maskgit_generate(
                inputs_masked, out_t=timestep, maskgit_steps=self.args.maskgit_steps,
                temperature=self.args.temperature, action_ids=action_ids, domain=domains,
                skip_normalization=skip_normalization
            )
            all_samples.append(samples_HW)
            all_logits.append(factored_logits)

        samples_THW = torch.stack(all_samples, dim=1)
        return samples_THW, torch.stack(all_logits, dim=3)
    def predict_next_frames(self, samples_THW) -> torch.Tensor:
        """
        All model submissions should have this defined.

        Like `predict_zframe_logits`, generation is parallel across the spatial dimensions and
        teacher-forced across the time dimension (or autoregressive with `--autoregressive_time`).
        Conditioned on each prefix: [frame_0], [frame_0, frame_1], ..., [frame_0, frame_1, ..., frame_{T-1}],
        predict the following frame: [pred_frame_1, pred_frame_2, ..., pred_frame_T].

        For this model, the frames are generated by decoding the latents sampled by
        `predict_zframe_logits` back to the original image space.

        Args:
            samples_THW: Tensor of size (B, T-1, H, W) corresponding to sampled images in the latent space.

        Returns:
            Tensor of size (B, T-1, 3, 256, 256) corresponding to the predicted frames.
        """
        # Latents were scaled by SVD_SCALE at encoding time, so undo the scaling before decoding.
        return decode_features(samples_THW.cpu() / SVD_SCALE, self.decode_latents)
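# A minimal sketch of how the two prediction hooks above compose (shapes follow the
# docstrings; `args`, `decode_latents`, and `input_ids` are assumed to be prepared):
#   evaluator = GenieEvaluator(args, decode_latents)
#   samples, factored_logits = evaluator.predict_zframe_logits(input_ids)  # samples: (B, T-1, H, W)
#   frames = evaluator.predict_next_frames(samples)                        # (B, T-1, 3, 256, 256)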
def main():
    transformers.set_seed(42)
    args = parse_args()

    # Allow different batch sizes in the final batch.
    accelerator = accelerate.Accelerator(dataloader_config=DataLoaderConfiguration(even_batches=False))

    # if "robomimic" in args.val_data_dir:
    #     dataset = "robomimic"
    # Save the results to wandb. The dataset name is parsed from the input path, which is
    # hardcoded to contain "magvit" for now; this will change later.
    dataset = re.search(r"data/(.*?)_magvit", args.val_data_dir).group(1)

    # Strip any trailing slash and use the last path component as the run name.
    args.checkpoint_dir = args.checkpoint_dir.rstrip('/')
    name = args.checkpoint_dir.split('/')[-1]

    decode_latents = decode_latents_wrapper(device=accelerator.device,
                                            encoder_name_or_path="stabilityai/stable-video-diffusion-img2vid",
                                            encoder_type="temporalvae")
    evaluator = GenieEvaluator(args, decode_latents)

    # The model may consume a chunk of actions spanning several dataset timesteps;
    # infer that stride from the per-step action dim and the model's action horizon.
    action_d = len(evaluator.model.action_preprocessor[dataset].mean)
    action_d_horizon = evaluator.model.config.d_actions[evaluator.model.config.action_domains.index(dataset)]
    stride = action_d_horizon // action_d
    print("model stride:", stride)
    if accelerator.is_main_process:
        wandb.teardown()
        wandb.init(project='video_val', resume="allow", id=f"{args.project_prefix}{name}",
                   name=f"{args.project_prefix}{name}", settings=wandb.Settings(start_method="thread"))

    with_action_input = True
    if args.use_raw_image:
        args.val_data_dir = args.val_data_dir.replace("magvit", "image")
        val_dataset = RawImageDataset(args.val_data_dir, window_size=WINDOW_SIZE, compute_stride_from_freq_table=False,
                                      stride=stride, filter_overlaps=True,
                                      use_actions=with_action_input)
    else:
        # args.val_data_dir = args.val_data_dir.replace("magvit", "vae")
        args.val_data_dir = args.val_data_dir.replace("magvit_traj1000000", "noquant_temporalvae_shard0_of_1")
        val_dataset = RawFeatureDataset(args.val_data_dir, window_size=WINDOW_SIZE, compute_stride_from_freq_table=False,
                                        stride=stride, filter_overlaps=True,
                                        use_actions=with_action_input)
    dataset_metadata = val_dataset.metadata

    assert hasattr(evaluator, "model"), "Expected Evaluator to have attribute `model`."
    evaluator.model = accelerator.prepare_model(evaluator.model, evaluation_mode=True)  # No DDP
    with_action_input = evaluator.model.config.use_actions  # hack: reset from the model config

    lpips_alex = lpips.LPIPS(net="alex")  # Calculate LPIPS with AlexNet, the fastest of the available backbones

    random_samples = None
    if args.max_examples is not None:
        val_dataset.valid_start_inds = val_dataset.valid_start_inds[:args.max_examples]
    dataloader = DataLoader(val_dataset, collate_fn=default_data_collator, batch_size=args.batch_size)

    metrics = defaultdict(AvgMetric)
    batch_idx = 0
    latent_side_len = 32  # hardcoded latent resolution
    args.latent_h = args.latent_w = latent_side_len
    dataloader = accelerator.prepare(dataloader)

    gt_full_sequence = []
    generated_full_sequence = []
    for batch in tqdm(dataloader):
        batch_idx += 1
        if args.use_raw_image:
            # Tokenize the batches on the fly with the SVD temporal VAE.
            images = batch["images"].detach().cpu().numpy().astype(np.uint8)
            outputs = []
            for context in images:
                output = []
                for image_t in context:
                    output_t = utils.get_vae_image_embeddings(
                        image_t,
                        encoder_type="temporalvae",
                        encoder_name_or_path="stabilityai/stable-video-diffusion-img2vid",
                    )
                    output.append(output_t)
                outputs.append(output)
            batch["input_ids"] = torch.FloatTensor(outputs).to(evaluator.device)
            batch["input_ids"] = rearrange(batch["input_ids"], "b t c h w -> b (t h w) c") * SVD_SCALE
            batch["labels"] = batch["input_ids"].clone()

        batch_size = batch["input_ids"].size(0)
        reshaped_input_ids = rearrange(batch["input_ids"], "b (t h w) ... -> b t h w ...", t=WINDOW_SIZE,
                                       h=latent_side_len, w=latent_side_len)

        start_time = time.time()
        if not with_action_input:
            samples, _ = evaluator.predict_zframe_logits(batch["input_ids"].to(evaluator.device), domains=[val_dataset.name])
        else:
            samples, _ = evaluator.predict_zframe_logits(batch["input_ids"].to(evaluator.device),
                                                         batch["action_ids"].to(evaluator.device), [val_dataset.name])
        frames_per_batch = (WINDOW_SIZE - 1) * batch["input_ids"].size(0)
        metrics["gen_time"].update((time.time() - start_time) / frames_per_batch, batch_size)

        start_time = time.time()
        pred_frames = evaluator.predict_next_frames(samples)
        metrics["dec_time"].update((time.time() - start_time) / frames_per_batch, batch_size)

        decoded_gtruth = decode_features(reshaped_input_ids / SVD_SCALE, decode_latents)
        decoded_gtruth_clone = batch['images'].permute(0, 1, 4, 2, 3)[:len(decoded_gtruth)]
        if args.use_raw_image:  # key: use the raw images as the ground truth
            decoded_gtruth = batch['images'].permute(0, 1, 4, 2, 3)[:len(decoded_gtruth)].long().cpu().detach()
        metrics["pred_lpips"].update_list(compute_lpips(decoded_gtruth[:, 1:], pred_frames, lpips_alex))

        gt_frames_numpy = decoded_gtruth[:, 1:].detach().cpu().numpy()
        pred_frames_numpy = pred_frames.detach().cpu().numpy()

        # Optionally log example frames to wandb:
        # if accelerator.is_main_process:
        #     for i in range(gt_frames_numpy.shape[0] // 4):
        #         wandb.log({
        #             f"{dataset}/gt_{i}": [wandb.Image(np.transpose(gt_frames_numpy[i][j], (1, 2, 0))) for j in range(gt_frames_numpy.shape[1] // 2)],
        #             f"{dataset}/pred_{i}": [wandb.Image(np.transpose(pred_frames_numpy[i][j], (1, 2, 0))) for j in range(pred_frames_numpy.shape[1] // 2)],
        #         })

        # PSNR is computed on the final predicted frame only; SSIM is averaged over the batch per timestep.
        psnr = [image_metrics.peak_signal_noise_ratio(
            gt_frames_numpy[i][-1] / 255., pred_frames_numpy[i][-1] / 255., data_range=1.0)
            for i in range(gt_frames_numpy.shape[0])]
        ssim = [np.mean([image_metrics.structural_similarity(
            gt_frames_numpy[i][j] / 255., pred_frames_numpy[i][j] / 255., data_range=1.0, channel_axis=0)
            for i in range(gt_frames_numpy.shape[0])]) for j in range(gt_frames_numpy.shape[1])]
        metrics["ssim"].update_list(ssim)
        metrics["psnr"].update_list(psnr)

        gt_full_sequence.append(decoded_gtruth_clone[:, 1:])
        generated_full_sequence.append(pred_frames)

        # As in Genie, we also compute psnr_delta = PSNR(x_t, x_t_hat) - PSNR(x_t, x_t_hatprime),
        # where x_t_hatprime is predicted from randomly sampled actions (drawn below from a
        # Gaussian fit to the dataset's action statistics). This difference in PSNR measures
        # how strongly the predictions depend on the actions, i.e. controllability.
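        # Illustrative reading of the metric (numbers are hypothetical): if PSNR with the
        # true actions is 24.0 dB and PSNR with random actions is 21.5 dB, then
        # psnr_delta = 2.5 dB; larger deltas indicate stronger action conditioning.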
        if with_action_input:
            # Average the delta PSNR over several random-action trials.
            N_TRIALS = 5
            psnr_delta_mean = np.zeros(gt_frames_numpy.shape[0])
            for _ in range(N_TRIALS):
                action_mean = evaluator.model.action_preprocessor[dataset].mean.repeat(stride)
                action_std = evaluator.model.action_preprocessor[dataset].std.repeat(stride)
                random_action_ids = torch.randn_like(batch["action_ids"]) * action_std + action_mean
                random_samples, _ = evaluator.predict_zframe_logits(batch["input_ids"].to(evaluator.device),
                                                                    random_action_ids.to(evaluator.device), [val_dataset.name],
                                                                    skip_normalization=False)
                random_pred_frames = evaluator.predict_next_frames(random_samples)
                random_pred_frames_numpy = random_pred_frames.detach().cpu().numpy()

                # Subtract the random-action PSNR from the true-action PSNR.
                psnr_delta = [psnr[i] - image_metrics.peak_signal_noise_ratio(
                    gt_frames_numpy[i][-1] / 255., random_pred_frames_numpy[i][-1] / 255., data_range=1.0)
                    for i in range(gt_frames_numpy.shape[0])]
                psnr_delta_mean += np.array(psnr_delta) / N_TRIALS
            metrics["psnr_delta"].update_list(psnr_delta_mean)

        print(f"=== dataset {dataset} model: {name}")
        print({key: f"{val.mean():.4f}" for key, val in metrics.items()})
        # Note: `batch_idx` counts batches while `max_examples` counts examples; the dataset
        # was already truncated above, so this is only a safety stop.
        if batch_idx > args.max_examples:
            break
    # Save the generated and ground-truth sequences, normalized to [0, 1].
    generated_full_sequence = torch.cat(generated_full_sequence, dim=0) / 255.
    gt_full_sequence = torch.cat(gt_full_sequence, dim=0) / 255.
    gt_full_sequence.detach().cpu().numpy().tofile(args.checkpoint_dir + "/gt_video.bin")
    generated_full_sequence.detach().cpu().numpy().tofile(args.checkpoint_dir + "/generated_video.bin")

    # FID and FVD are computed once over the full sequences.
    metrics["fid"].update_list([calculate_fid(gt_full_sequence, generated_full_sequence, device=accelerator.device)])
    metrics["fvd"].update_list([calculate_fvd(gt_full_sequence, generated_full_sequence, device=accelerator.device)])

    for key, val in metrics.items():
        agg_total, agg_count = accelerator.reduce(
            torch.tensor([val.total, val.count], device=accelerator.device)
        )
        accelerator.print(f"{key}: {agg_total / agg_count:.4f}")

    if accelerator.is_main_process:
        prefix = "teacher_force" if not args.autoregressive_time else "autoregressive"
        for key, val in metrics.items():
            try:
                wandb.log({f"{dataset}/{prefix}_{key}": val.mean()})
                wandb.log({f"{prefix}_{key}": val.mean()})
            except Exception as e:
                print(e)
        wandb.log({f"{dataset}/num_examples": len(val_dataset)})
        wandb.log({f"{dataset}/perturbation_scale": args.perturbation_scale})
        wandb.log({"model_step": evaluator.model_step})  # model training steps

        dataset_metadata = {
            f"{dataset}/dataset_name": dataset,
            f"{dataset}/num_examples": len(val_dataset),
            f"{dataset}/num_features": len(val_dataset[0]) if len(val_dataset) > 0 else 0,
            f"{dataset}/sample_data": val_dataset[0] if len(val_dataset) > 0 else "N/A",
            f"{dataset}/model_step": evaluator.model_step,
        }
        for k, v in dataset_metadata.items():
            wandb.run.summary[k] = v
        wandb.finish()


if __name__ == "__main__":
    main()
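# Example invocations (checkpoint paths are illustrative):
#   python genie/evaluate.py --checkpoint_dir 1x-technologies/GENIE_35M
#   python genie/evaluate.py --checkpoint_dir <ckpt_dir> --autoregressive_time --maskgit_steps 8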