old_tok / making_cache.py
tennant's picture
upload
af7c0ce
import numpy as np
import os, pdb, time
import torch_fidelity
import tqdm
import torch
import os.path as osp
import argparse
from omegaconf import OmegaConf
from paintmind.engine.util import instantiate_from_config
@torch.no_grad()
def caching():
    """Encode the configured dataset with the configured model and save the latents.

    The config path is read from the ``--cfg`` command-line option
    (default ``configs/vit_vqgan.yaml``). The dataset and model are built from
    ``config.trainer.params.dataset`` and ``config.trainer.params.model``.
    For each batch, ``model.vae_encode`` is applied to the first element of
    the batch on the GPU; the results are moved to the CPU, concatenated along
    dimension 0 and written with ``torch.save`` to
    ``config.trainer.params.latent_cache_file``.

    Gradient tracking is disabled for the whole function by the
    ``torch.no_grad()`` decorator.

    Raises:
        FileNotFoundError: If the config file given by ``--cfg`` does not exist.
        RuntimeError: If the dataloader yields no batches.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='configs/vit_vqgan.yaml')
    args = parser.parse_args()

    cfg_file = args.cfg
    # Explicit check instead of `assert`: asserts are stripped under `python -O`.
    if not osp.exists(cfg_file):
        raise FileNotFoundError(f"Config file not found: {cfg_file}")

    config = OmegaConf.load(cfg_file)
    dataset = instantiate_from_config(config.trainer.params.dataset)
    model = instantiate_from_config(config.trainer.params.model)

    # shuffle=False keeps the cached latents in dataset order.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.trainer.params.batch_size,
        shuffle=False,
        num_workers=config.trainer.params.num_workers,
    )

    cache_save_file = config.trainer.params.latent_cache_file
    # Create the output directory before the (potentially long) encoding pass
    # so a missing directory cannot cause a failure only at save time.
    cache_dir = osp.dirname(cache_save_file)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    model.cuda()
    model.eval()

    # Per the original author's note, each batch is expected to provide an
    # (N, C, H, W) image tensor as its first element -- TODO confirm against
    # the dataset implementation.
    latents = []
    for batch in tqdm.tqdm(dataloader):
        images = batch[0].cuda()
        # Move each result to the CPU immediately so GPU memory does not grow
        # with the number of batches.
        latents.append(model.vae_encode(images).cpu())

    if not latents:
        # torch.cat rejects an empty list; fail with a clearer message instead.
        raise RuntimeError(
            "Dataloader produced no batches; nothing to cache "
            f"(config: {cfg_file})"
        )

    cache = torch.cat(latents, dim=0)
    torch.save(cache, cache_save_file)
# Script entry point: build the latent cache only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    caching()