# vatrpp/generate/fid.py
# Author: vittoriopippi — commit fa0f216 ("Initial commit")
import os
from pathlib import Path
import torch
import torch.utils.data
from data.dataset import FidDataset
from generate.writer import Writer
def _num_writers_for(dataset_path):
    """Return the writer count for an IAM (339) or CVL (283) target dataset path.

    The last path component is checked first, so an unrelated parent
    directory cannot select the wrong dataset (e.g. ``/home/william/cvl.pickle``
    contains ``iam`` inside ``william``). The full path is only consulted as a
    fallback, which keeps layouts such as ``data/iam/train.pickle`` working.

    Raises:
        ValueError: if neither ``iam`` nor ``cvl`` occurs in the path.
    """
    for candidate in (Path(dataset_path).name.lower(), dataset_path.lower()):
        if 'iam' in candidate:
            return 339
        if 'cvl' in candidate:
            return 283
    raise ValueError(
        f"Cannot infer the number of writers from target dataset path "
        f"{dataset_path!r}: expected it to contain 'iam' or 'cvl'"
    )


def _dataset_stem(path):
    """Return the last component of *path* with '.pickle' and every '-' removed."""
    return Path(path).name.replace(".pickle", "").replace("-", "")


def _model_tag(checkpoint, style_dataset_path):
    """Build the tag used to name generated-image folders.

    The tag is the suffix after the last '-' of the model folder name (or
    'vatr' if the folder name has no '-'), followed by '_' and the stem of the
    style dataset. When *checkpoint* points at a ``.pth`` file, the model
    folder is that file's parent directory; otherwise *checkpoint* itself is
    taken as the folder.
    """
    ckpt = Path(checkpoint)
    model_folder = ckpt.parent.name if checkpoint.endswith(".pth") else ckpt.name
    tag = model_folder.split("-")[-1] if "-" in model_folder else "vatr"
    return f"{tag}_{_dataset_stem(style_dataset_path)}"


def _epoch_checkpoints(checkpoint_dir):
    """Map epoch number -> path for each ``<digits>_model.pth`` file in *checkpoint_dir*.

    The mapping is returned in ascending epoch order. Files that do not
    follow that naming pattern are ignored instead of crashing ``int()``.
    The real file paths are kept, so checkpoints saved with a zero-padding
    other than 4 digits are still found.
    """
    suffix = "_model.pth"
    found = {}
    for path in Path(checkpoint_dir).glob(f"*{suffix}"):
        prefix = path.name[: -len(suffix)]
        if prefix.isdecimal():
            found[int(prefix)] = path
    return dict(sorted(found.items()))


def _make_loader(dataset, batch_size, num_workers):
    """Return an unshuffled DataLoader that keeps the final partial batch."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=dataset.collate_fn,
    )


def generate_fid(args):
    """Generate the image sets used for FID evaluation from one or more checkpoints.

    Reads these attributes from *args*: ``target_dataset_path``, ``alphabet``,
    ``num_examples``, ``resolution``, ``dataset_path``, ``batch_size``,
    ``num_workers``, ``output``, ``checkpoint``, ``all_epochs``, ``test_only``,
    ``fake_only`` and ``long_tail``.

    Side effects on *args*: sets ``num_writers`` (339 for IAM, 283 for CVL) and
    ``vocab_size`` (``len(args.alphabet)``). It also replaces ``output`` with
    ``<output or 'saved_images'>/fid/<target dataset stem>``. The actual image
    writing is delegated to ``Writer.generate_fid``.

    Behaviour depends on ``args.all_epochs``:

    - False: load ``args.checkpoint``. Run the train split (unless
      ``args.test_only``) and then the test split.
    - True: treat ``args.checkpoint`` as a directory of ``<epoch>_model.pth``
      files and run the test split for each epoch in ascending order. Real
      images are requested only for the first epoch (and not at all if
      ``args.fake_only`` is set).

    Raises:
        ValueError: if the target dataset path mentions neither IAM nor CVL.
        FileNotFoundError: if ``all_epochs`` is set but no
            ``<epoch>_model.pth`` checkpoint exists in ``args.checkpoint``.
    """
    args.num_writers = _num_writers_for(args.target_dataset_path)
    args.vocab_size = len(args.alphabet)

    dataset_kwargs = dict(
        base_path=args.target_dataset_path,
        num_examples=args.num_examples,
        collator_resolution=args.resolution,
        style_dataset=args.dataset_path,
    )
    dataset_train = FidDataset(mode='train', **dataset_kwargs)
    train_loader = _make_loader(dataset_train, args.batch_size, args.num_workers)
    dataset_test = FidDataset(mode='test', **dataset_kwargs)
    # NOTE(review): the test loader always uses num_workers=0 (kept from the
    # original); confirm whether ignoring args.num_workers here is intended.
    test_loader = _make_loader(dataset_test, args.batch_size, 0)

    output_root = 'saved_images' if args.output is None else args.output
    args.output = Path(output_root) / 'fid' / _dataset_stem(args.target_dataset_path)
    model_tag = _model_tag(args.checkpoint, args.dataset_path)

    if not args.all_epochs:
        writer = Writer(args.checkpoint, args, only_generator=True)
        if not args.test_only:
            writer.generate_fid(args.output, train_loader, model_tag=model_tag, split='train',
                                fake_only=args.fake_only, long_tail_only=args.long_tail)
        writer.generate_fid(args.output, test_loader, model_tag=model_tag, split='test',
                            fake_only=args.fake_only, long_tail_only=args.long_tail)
    else:
        checkpoints = _epoch_checkpoints(args.checkpoint)
        if not checkpoints:
            raise FileNotFoundError(
                f"No '<epoch>_model.pth' checkpoints found in {args.checkpoint!r}"
            )
        for index, (epoch, checkpoint_path) in enumerate(checkpoints.items()):
            writer = Writer(str(checkpoint_path), args, only_generator=True)
            # Real images are requested only on the first epoch; every later
            # epoch adds fake images only. args.fake_only is honoured for
            # consistency with the single-checkpoint branch.
            writer.generate_fid(args.output, test_loader, model_tag=f"{model_tag}_{epoch}", split='test',
                                fake_only=args.fake_only or index > 0, long_tail_only=args.long_tail)
            # Drop the reference so the previous generator can be freed
            # before the next checkpoint is loaded.
            del writer
    print('Done')