import os
from pathlib import Path

import torch
import torch.utils.data

from data.dataset import FidDataset
from generate.writer import Writer


def _path_tag(path):
    """Return the last component of *path* with '.pickle' and '-' removed.

    ``Path.name`` is used instead of ``split("/")[-1]`` so that a trailing
    separator or a Windows-style path still yields the final component.
    """
    return Path(path).name.replace(".pickle", "").replace("-", "")


def _make_loader(args, mode, num_workers):
    """Build a non-shuffled DataLoader over a ``FidDataset`` split.

    Args:
        args: Namespace providing ``target_dataset_path``, ``num_examples``,
            ``resolution``, ``dataset_path`` and ``batch_size``.
        mode: Split name forwarded to ``FidDataset`` ('train' or 'test').
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` using the dataset's own ``collate_fn``.
    """
    dataset = FidDataset(
        base_path=args.target_dataset_path,
        num_examples=args.num_examples,
        collator_resolution=args.resolution,
        mode=mode,
        style_dataset=args.dataset_path,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=dataset.collate_fn,
    )


def _checkpoint_epochs(folder):
    """Return sorted ``(epoch, filename)`` pairs for ``<epoch>_model.pth`` files in *folder*.

    Only files named ``<digits>_model.pth`` are considered. Any other file that
    happens to contain an underscore is ignored rather than crashing ``int()``,
    and it is not counted as an extra epoch. The real filename is returned so
    callers open the file that exists instead of re-deriving a zero-padded name.
    """
    found = {}
    for name in os.listdir(folder):
        prefix, sep, rest = name.partition("_")
        if sep and rest == "model.pth" and prefix.isdecimal():
            found[int(prefix)] = name
    return sorted(found.items())


def generate_fid(args):
    """Run ``Writer.generate_fid`` for one checkpoint, or for every epoch checkpoint in a folder.

    Reads from ``args``: ``target_dataset_path``, ``alphabet``, ``num_examples``,
    ``resolution``, ``dataset_path``, ``batch_size``, ``num_workers``, ``output``,
    ``checkpoint``, ``all_epochs``, ``test_only``, ``fake_only`` and ``long_tail``.

    Mutates ``args``:

    - ``num_writers``: 339 for IAM, 283 for CVL.
    - ``vocab_size``: ``len(args.alphabet)``.
    - ``output``: becomes ``<output or 'saved_images'>/fid/<target dataset tag>``.

    When ``args.all_epochs`` is false, ``args.checkpoint`` is loaded once. Its
    test split, plus its train split unless ``args.test_only`` is set, is passed
    to ``Writer.generate_fid``.

    Otherwise ``args.checkpoint`` is treated as a folder of ``<epoch>_model.pth``
    files, and each one is run on the test split in ascending epoch order. Only
    the first epoch is allowed to write real images (``fake_only=False``), and
    only if ``args.fake_only`` is not set.

    Raises:
        ValueError: if ``args.target_dataset_path`` contains neither 'iam' nor
            'cvl' (case-insensitive), so the writer count cannot be inferred.
    """
    target_lower = args.target_dataset_path.lower()
    if 'iam' in target_lower:
        args.num_writers = 339
    elif 'cvl' in target_lower:
        args.num_writers = 283
    else:
        raise ValueError(
            f"Cannot infer the number of writers from target_dataset_path "
            f"{args.target_dataset_path!r}: expected it to contain 'iam' or 'cvl'"
        )

    args.vocab_size = len(args.alphabet)

    # The test split is needed in every mode. It keeps num_workers=0, unlike the
    # train loader, which uses args.num_workers. Both settings are unchanged.
    test_loader = _make_loader(args, 'test', num_workers=0)

    output_root = 'saved_images' if args.output is None else args.output
    args.output = Path(output_root) / 'fid' / _path_tag(args.target_dataset_path)

    # For a '.pth' file the model is named after its parent folder; otherwise
    # the checkpoint path itself is the folder.
    checkpoint = Path(args.checkpoint)
    model_folder = checkpoint.parent.name if args.checkpoint.endswith(".pth") else checkpoint.name
    model_tag = model_folder.split("-")[-1] if "-" in model_folder else "vatr"
    model_tag += "_" + _path_tag(args.dataset_path)

    if not args.all_epochs:
        writer = Writer(args.checkpoint, args, only_generator=True)
        if not args.test_only:
            # Build the train split only when it is actually consumed.
            train_loader = _make_loader(args, 'train', num_workers=args.num_workers)
            writer.generate_fid(args.output, train_loader, model_tag=model_tag, split='train',
                                fake_only=args.fake_only, long_tail_only=args.long_tail)
        writer.generate_fid(args.output, test_loader, model_tag=model_tag, split='test',
                            fake_only=args.fake_only, long_tail_only=args.long_tail)
    else:
        epochs = _checkpoint_epochs(args.checkpoint)
        if not epochs:
            print(f"No '<epoch>_model.pth' checkpoints found in {args.checkpoint}")

        # Real images are written at most once (first epoch), and never if the
        # user asked for fake-only generation.
        generate_real = not args.fake_only
        for epoch, filename in epochs:
            writer = Writer(os.path.join(args.checkpoint, filename), args, only_generator=True)
            writer.generate_fid(args.output, test_loader, model_tag=f"{model_tag}_{epoch}", split='test',
                                fake_only=not generate_real, long_tail_only=args.long_tail)
            generate_real = False
            # Release this model before the next checkpoint is loaded to lower peak memory.
            del writer

    print('Done')