# vatrpp/generate/ocr.py
# vittoriopippi
# Initial commit
# fa0f216
import os
import shutil
import cv2
import msgpack
import torch
from data.dataset import CollectionTextDataset, TextDataset, FolderDataset, FidDataset, get_dataset_path
from generate.writer import Writer
from util.text import get_generator
def generate_ocr(args):
    """
    Create OCR training data whose words are drawn from a text generator.

    Reads ``dataset``, ``file_suffix``, ``num_examples``, ``resolution``,
    ``checkpoint``, ``count``, ``interp_styles`` and ``output`` from ``args``
    and overwrites ``args.num_writers`` with the value reported by the dataset.
    """
    style_source = CollectionTextDataset(
        args.dataset,
        'files',
        TextDataset,
        file_suffix=args.file_suffix,
        num_examples=args.num_examples,
        collator_resolution=args.resolution,
        validation=True,
    )

    # Must be set before ``args`` is handed to the Writer below.
    args.num_writers = style_source.num_writers

    ocr_writer = Writer(args.checkpoint, args, only_generator=True)
    text_source = get_generator(args)

    ocr_writer.generate_ocr(
        style_source,
        args.count,
        interpolate_style=args.interp_styles,
        output_folder=args.output,
        text_generator=text_source,
    )
def generate_ocr_reference(args):
    """
    Create OCR training data using the words found in the given dataset.

    Per the original description, the reference words are saved as well.
    Reads ``dataset``, ``file_suffix``, ``num_examples``, ``resolution``,
    ``checkpoint``, ``count``, ``interp_styles``, ``output`` and ``long_tail``
    from ``args`` and overwrites ``args.num_writers`` with the dataset's value.
    """
    reference_data = CollectionTextDataset(
        args.dataset,
        'files',
        TextDataset,
        file_suffix=args.file_suffix,
        num_examples=args.num_examples,
        collator_resolution=args.resolution,
        validation=True,
    )
    # Disabled alternative data source, kept for reference:
    # dataset = FidDataset(get_dataset_path(args.dataset, 32, args.file_suffix, 'files'), mode='test', collator_resolution=args.resolution)

    # Must be set before ``args`` is handed to the Writer below.
    args.num_writers = reference_data.num_writers

    ocr_writer = Writer(args.checkpoint, args, only_generator=True)
    ocr_writer.generate_ocr(
        reference_data,
        args.count,
        interpolate_style=args.interp_styles,
        output_folder=args.output,
        long_tail=args.long_tail,
    )
def generate_ocr_msgpack(args):
    """
    Generate an OCR dataset whose words are listed in a msgpack file.

    The msgpack file at ``args.text_path`` is iterated as a sequence of
    ``(filename, target)`` pairs. For every pair whose output image does not
    exist yet, a style sample is taken from the folder dataset, the target
    text is rendered and the image is written to ``args.output/filename``.
    Existing files are skipped, so an interrupted run can be resumed.

    Args:
        args: Namespace-like object. Reads ``dataset_path``, ``charset_file``,
            ``checkpoint``, ``text_path``, ``output`` and ``device``.
            Sets ``num_writers`` and, when ``charset_file`` is truthy,
            ``alphabet``.
    """
    dataset = FolderDataset(args.dataset_path)
    # NOTE(review): hard-coded writer count; presumably it has to match the
    # checkpoint being loaded -- confirm before changing.
    args.num_writers = 339

    if args.charset_file:
        # Context manager so the file handle is closed after decoding.
        with open(args.charset_file, 'rb') as charset_fp:
            charset = msgpack.load(charset_fp, use_list=False, strict_map_key=False)
        args.alphabet = "".join(charset['char2idx'].keys())

    writer = Writer(args.checkpoint, args, only_generator=True)

    with open(args.text_path, 'rb') as text_fp:
        lines = msgpack.load(text_fp, use_list=False)

    print(f"Generating {len(lines)} to {args.output}")

    for filename, target in lines:
        out_path = os.path.join(args.output, filename)
        if os.path.exists(out_path):
            # Already generated on a previous run.
            continue

        style = torch.unsqueeze(dataset.sample_style()['simg'], dim=0).to(args.device)
        fake = writer.create_fake_sentence(style, target, at_once=True)

        # cv2.imwrite does not create directories and reports failure only
        # through its return value, so make sure the parent exists first.
        parent_dir = os.path.dirname(out_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        if not cv2.imwrite(out_path, fake):
            print(f"Warning: failed to write {out_path}")

    print("Done")