|
import os |
|
from pathlib import Path |
|
|
|
import torch |
|
import torch.utils.data |
|
|
|
from data.dataset import FidDataset |
|
from generate.writer import Writer |
|
|
|
|
|
def generate_fid(args): |
|
if 'iam' in args.target_dataset_path.lower(): |
|
args.num_writers = 339 |
|
elif 'cvl' in args.target_dataset_path.lower(): |
|
args.num_writers = 283 |
|
else: |
|
raise ValueError |
|
|
|
args.vocab_size = len(args.alphabet) |
|
|
|
dataset_train = FidDataset(base_path=args.target_dataset_path, num_examples=args.num_examples, collator_resolution=args.resolution, mode='train', style_dataset=args.dataset_path) |
|
train_loader = torch.utils.data.DataLoader( |
|
dataset_train, |
|
batch_size=args.batch_size, |
|
shuffle=False, |
|
num_workers=args.num_workers, |
|
pin_memory=True, drop_last=False, |
|
collate_fn=dataset_train.collate_fn |
|
) |
|
|
|
dataset_test = FidDataset(base_path=args.target_dataset_path, num_examples=args.num_examples, collator_resolution=args.resolution, mode='test', style_dataset=args.dataset_path) |
|
test_loader = torch.utils.data.DataLoader( |
|
dataset_test, |
|
batch_size=args.batch_size, |
|
shuffle=False, |
|
num_workers=0, |
|
pin_memory=True, drop_last=False, |
|
collate_fn=dataset_test.collate_fn |
|
) |
|
|
|
args.output = 'saved_images' if args.output is None else args.output |
|
args.output = Path(args.output) / 'fid' / args.target_dataset_path.split("/")[-1].replace(".pickle", "").replace("-", "") |
|
|
|
model_folder = args.checkpoint.split("/")[-2] if args.checkpoint.endswith(".pth") else args.checkpoint.split("/")[-1] |
|
model_tag = model_folder.split("-")[-1] if "-" in model_folder else "vatr" |
|
model_tag += "_" + args.dataset_path.split("/")[-1].replace(".pickle", "").replace("-", "") |
|
|
|
if not args.all_epochs: |
|
writer = Writer(args.checkpoint, args, only_generator=True) |
|
if not args.test_only: |
|
writer.generate_fid(args.output, train_loader, model_tag=model_tag, split='train', fake_only=args.fake_only, long_tail_only=args.long_tail) |
|
writer.generate_fid(args.output, test_loader, model_tag=model_tag, split='test', fake_only=args.fake_only, long_tail_only=args.long_tail) |
|
else: |
|
epochs = sorted([int(f.split("_")[0]) for f in os.listdir(args.checkpoint) if "_" in f]) |
|
generate_real = True |
|
|
|
for epoch in epochs: |
|
checkpoint_path = os.path.join(args.checkpoint, f"{str(epoch).zfill(4)}_model.pth") |
|
writer = Writer(checkpoint_path, args, only_generator=True) |
|
writer.generate_fid(args.output, test_loader, model_tag=f"{model_tag}_{epoch}", split='test', fake_only=not generate_real, long_tail_only=args.long_tail) |
|
generate_real = False |
|
|
|
print('Done') |
|
|