|
import os |
|
import shutil |
|
|
|
import cv2 |
|
import msgpack |
|
import torch |
|
|
|
from data.dataset import CollectionTextDataset, TextDataset, FolderDataset, FidDataset, get_dataset_path |
|
from generate.writer import Writer |
|
from util.text import get_generator |
|
|
|
|
|
def generate_ocr(args):
    """
    Generate OCR training data with words drawn from a text generator.

    Attributes read from ``args``: ``dataset``, ``file_suffix``,
    ``num_examples``, ``resolution``, ``checkpoint``, ``count``,
    ``interp_styles`` and ``output``. ``args.num_writers`` is overwritten
    with the writer count reported by the loaded dataset before the model
    is constructed.

    Returns nothing; all output is produced by ``Writer.generate_ocr``.
    """
    # Style source: the validation view of the text dataset collection.
    style_collection = CollectionTextDataset(
        args.dataset,
        'files',
        TextDataset,
        file_suffix=args.file_suffix,
        num_examples=args.num_examples,
        collator_resolution=args.resolution,
        validation=True,
    )

    # Must be set before Writer is built, since Writer receives ``args``.
    args.num_writers = style_collection.num_writers

    ocr_writer = Writer(args.checkpoint, args, only_generator=True)
    word_source = get_generator(args)

    ocr_writer.generate_ocr(
        style_collection,
        args.count,
        interpolate_style=args.interp_styles,
        output_folder=args.output,
        text_generator=word_source,
    )
|
|
|
|
|
def generate_ocr_reference(args):
    """
    Generate OCR training data using the words of the given dataset.

    Unlike the text-generator variant, no ``text_generator`` is passed to
    ``Writer.generate_ocr``; ``args.long_tail`` is forwarded instead.
    Reference words are also saved (done inside ``Writer.generate_ocr`` --
    not visible from this function).

    Attributes read from ``args``: ``dataset``, ``file_suffix``,
    ``num_examples``, ``resolution``, ``checkpoint``, ``count``,
    ``interp_styles``, ``output`` and ``long_tail``. ``args.num_writers``
    is overwritten with the writer count reported by the loaded dataset.
    """
    # Validation view of the dataset collection; supplies styles and words.
    reference_collection = CollectionTextDataset(
        args.dataset,
        'files',
        TextDataset,
        file_suffix=args.file_suffix,
        num_examples=args.num_examples,
        collator_resolution=args.resolution,
        validation=True,
    )

    # Must be set before Writer is built, since Writer receives ``args``.
    args.num_writers = reference_collection.num_writers

    ocr_writer = Writer(args.checkpoint, args, only_generator=True)

    ocr_writer.generate_ocr(
        reference_collection,
        args.count,
        interpolate_style=args.interp_styles,
        output_folder=args.output,
        long_tail=args.long_tail,
    )
|
|
|
|
|
def generate_ocr_msgpack(args):
    """
    Generate OCR dataset. Words generated are specified in given msgpack file.

    The msgpack file at ``args.text_path`` is iterated as
    ``(filename, text)`` pairs. For every pair whose image does not yet
    exist under ``args.output``, a style is sampled from the
    ``FolderDataset`` at ``args.dataset_path``, the text is rendered with
    ``Writer.create_fake_sentence`` and the result is written with
    ``cv2.imwrite``. Files that already exist are skipped, so a run can be
    restarted without regenerating earlier images.

    Attributes read from ``args``: ``dataset_path``, ``charset_file``,
    ``checkpoint``, ``text_path``, ``output`` and ``device``.
    ``args.num_writers`` is always overwritten, and ``args.alphabet`` is
    overwritten when ``args.charset_file`` is set.

    Returns nothing.
    """
    dataset = FolderDataset(args.dataset_path)

    # NOTE(review): hard-coded writer count; presumably it has to match the
    # checkpoint being loaded -- confirm before using other checkpoints.
    args.num_writers = 339

    if args.charset_file:
        # Close the file deterministically instead of leaking the handle.
        with open(args.charset_file, 'rb') as charset_stream:
            charset = msgpack.load(charset_stream, use_list=False, strict_map_key=False)
        args.alphabet = "".join(charset['char2idx'].keys())

    writer = Writer(args.checkpoint, args, only_generator=True)

    with open(args.text_path, 'rb') as text_stream:
        lines = msgpack.load(text_stream, use_list=False)

    print(f"Generating {len(lines)} to {args.output}")

    # cv2.imwrite does not create directories and signals failure only via
    # its return value, so make sure the target folder exists up front.
    if args.output:
        os.makedirs(args.output, exist_ok=True)

    for filename, target in lines:
        destination = os.path.join(args.output, filename)
        if os.path.exists(destination):
            # Already generated on a previous run.
            continue

        style = torch.unsqueeze(dataset.sample_style()['simg'], dim=0).to(args.device)
        fake = writer.create_fake_sentence(style, target, at_once=True)

        # Report failed writes instead of dropping the image silently; keep
        # going so one bad entry does not abort the whole run.
        if not cv2.imwrite(destination, fake):
            print(f"Failed to write {destination}")

    print("Done")
|
|