|
import numpy as np |
|
import os, pdb, time |
|
import torch_fidelity |
|
import tqdm |
|
import torch |
|
import os.path as osp |
|
import argparse |
|
from omegaconf import OmegaConf |
|
from paintmind.engine.util import instantiate_from_config |
|
|
|
|
|
@torch.no_grad()
def caching(argv=None):
    """Encode every dataset sample with the model's VAE and save the latents.

    Loads an OmegaConf YAML config (``--cfg``) and instantiates the dataset
    and model described under ``trainer.params``. It then runs
    ``model.vae_encode`` over the dataset in its natural order, concatenates
    the per-batch outputs along dim 0, and writes the result with
    ``torch.save`` to ``trainer.params.latent_cache_file``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before.

    Raises:
        FileNotFoundError: If the config file does not exist.
        RuntimeError: If the dataloader yields no batches (nothing to cache).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='configs/vit_vqgan.yaml')
    args = parser.parse_args(argv)

    cfg_file = args.cfg
    # Use an explicit raise rather than ``assert``, which ``python -O`` strips.
    if not osp.exists(cfg_file):
        raise FileNotFoundError(f"config file not found: {cfg_file}")
    config = OmegaConf.load(cfg_file)
    params = config.trainer.params

    dataset = instantiate_from_config(params.dataset)
    model = instantiate_from_config(params.model)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=params.batch_size,
        # Keep dataset order so row i of the saved cache matches sample i.
        shuffle=False,
        num_workers=params.num_workers,
    )

    cache_save_file = params.latent_cache_file

    model.cuda()
    model.eval()

    cache = []
    for batch in tqdm.tqdm(dataloader):
        # Assumes each batch is a sequence whose first element is the input
        # tensor (e.g. (images, labels)) -- TODO confirm against the dataset.
        images = batch[0].cuda()
        latent = model.vae_encode(images)
        # Move each result to CPU right away so latents do not pile up on GPU.
        cache.append(latent.cpu())

    if not cache:
        raise RuntimeError("dataset produced no batches; nothing to cache")
    cache = torch.cat(cache, dim=0)

    # Make sure the destination directory exists before writing.
    save_dir = osp.dirname(str(cache_save_file))
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(cache, cache_save_file)
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI options and build the latent cache.
    caching()
|
|
|
|