|
import os |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
|
|
from data.dataset import CollectionTextDataset, TextDataset |
|
from models.model import VATr |
|
from util.loading import load_checkpoint, load_generator |
|
|
|
|
|
def _make_loader(dataset):
    """Build a shuffling DataLoader of batch size 8 that uses ``dataset``'s own collator."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
        collate_fn=dataset.collate_fn,
    )


def _first_batch(loader, split):
    """Return the first batch yielded by ``loader``, raising a clear error if it yields none.

    Raises:
        ValueError: if ``loader`` produces no batch (with ``drop_last=True`` this
            happens when the split has fewer samples than ``loader.batch_size``).
    """
    try:
        return next(iter(loader))
    except StopIteration:
        raise ValueError(
            f"The {split} split produced no batch; with drop_last=True it needs at "
            f"least {loader.batch_size} samples."
        ) from None


def _save_page(page, path):
    """Scale ``page`` by 255, convert it to uint8 and write it to ``path`` with OpenCV.

    Assumes ``page`` is a NumPy array with values in [0, 1] (TODO confirm against
    ``VATr._generate_page``).

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file could not be written.
    """
    # Clip before the cast: out-of-range values would otherwise wrap around in uint8.
    image = np.clip(page * 255, 0, 255).astype(np.uint8)
    # cv2.imwrite signals failure through its return value rather than by raising.
    if not cv2.imwrite(path, image):
        raise OSError(f"Failed to write page image to {path!r}")


def generate_page(args):
    """Generate a page from training-split styles and a page from validation-split styles.

    Builds the training and validation ``CollectionTextDataset`` for ``args.dataset``
    and constructs ``VATr(args)``. It then loads generator weights from
    ``args.checkpoint`` and draws one shuffled batch from each split. Each batch is
    passed to ``model._generate_page`` under ``torch.no_grad()``, and the results
    are written to ``saved_images/pages/<output>_train.png`` and
    ``saved_images/pages/<output>_val.png``.

    Side effects on ``args``: sets ``output`` to ``'vatr'`` when it is None, and sets
    ``vocab_size`` and ``num_writers`` before ``VATr(args)`` is built (presumably so
    the model can read them — confirm).

    Raises:
        ValueError: if either split cannot fill a single batch.
        OSError: if a page image cannot be written.
    """
    args.output = 'vatr' if args.output is None else args.output
    args.vocab_size = len(args.alphabet)

    dataset = CollectionTextDataset(
        args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
        collator_resolution=args.resolution
    )
    datasetval = CollectionTextDataset(
        args.dataset, 'files', TextDataset, file_suffix=args.file_suffix, num_examples=args.num_examples,
        collator_resolution=args.resolution, validation=True
    )

    args.num_writers = dataset.num_writers

    model = VATr(args)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from trusted sources.
    checkpoint = torch.load(args.checkpoint, map_location=args.device)
    model = load_generator(model, checkpoint)

    data_train = _first_batch(_make_loader(dataset), 'training')
    data_val = _first_batch(_make_loader(datasetval), 'validation')

    model.eval()
    with torch.no_grad():
        # Pair each batch's style images ('simg') with that same batch's 'swids'
        # (presumably per-image style widths); mixing batches would mismatch them.
        page = model._generate_page(data_train['simg'].to(args.device), data_train['swids'])
        page_val = model._generate_page(data_val['simg'].to(args.device), data_val['swids'])

    out_dir = os.path.join("saved_images", "pages")
    # cv2.imwrite does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    _save_page(page, os.path.join(out_dir, f"{args.output}_train.png"))
    _save_page(page_val, os.path.join(out_dir, f"{args.output}_val.png"))
|
|