File size: 2,567 Bytes
fa0f216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import shutil

import cv2
import msgpack
import torch

from data.dataset import CollectionTextDataset, TextDataset, FolderDataset, FidDataset, get_dataset_path
from generate.writer import Writer
from util.text import get_generator


def generate_ocr(args):
    """
    Produce OCR training data whose words come from a text generator.

    Reads ``dataset``, ``file_suffix``, ``num_examples``, ``resolution``,
    ``checkpoint``, ``count``, ``interp_styles`` and ``output`` from ``args``
    and overwrites ``args.num_writers`` with the writer count reported by the
    loaded dataset.

    Args:
        args: Parsed run configuration (attribute-style namespace). It is
            mutated in place and also handed to ``Writer`` and
            ``get_generator``.
    """
    corpus = CollectionTextDataset(
        args.dataset,
        'files',
        TextDataset,
        file_suffix=args.file_suffix,
        num_examples=args.num_examples,
        collator_resolution=args.resolution,
        validation=True,
    )

    # The writer count must be on args before Writer is constructed from it.
    args.num_writers = corpus.num_writers
    synthesizer = Writer(args.checkpoint, args, only_generator=True)

    # Source of the words to render, selected from the run configuration.
    word_source = get_generator(args)

    synthesizer.generate_ocr(
        corpus,
        args.count,
        interpolate_style=args.interp_styles,
        output_folder=args.output,
        text_generator=word_source,
    )


def generate_ocr_reference(args):
    """
    Produce OCR training data using the words found in the dataset itself.

    Unlike the generator-driven variant, no text generator is supplied to the
    writer; ``args.long_tail`` is forwarded instead. Reference words are
    saved as well.

    Args:
        args: Parsed run configuration (attribute-style namespace). Its
            ``num_writers`` attribute is overwritten with the dataset's
            writer count before the ``Writer`` is built.
    """
    corpus = CollectionTextDataset(
        args.dataset,
        'files',
        TextDataset,
        file_suffix=args.file_suffix,
        num_examples=args.num_examples,
        collator_resolution=args.resolution,
        validation=True,
    )

    # Disabled alternative dataset, kept for reference:
    #corpus = FidDataset(get_dataset_path(args.dataset, 32, args.file_suffix, 'files'), mode='test', collator_resolution=args.resolution)

    # The writer count must be on args before Writer is constructed from it.
    args.num_writers = corpus.num_writers
    synthesizer = Writer(args.checkpoint, args, only_generator=True)

    synthesizer.generate_ocr(
        corpus,
        args.count,
        interpolate_style=args.interp_styles,
        output_folder=args.output,
        long_tail=args.long_tail,
    )


def generate_ocr_msgpack(args):
    """
    Generate OCR dataset. Words generated are specified in given msgpack file.

    The msgpack file at ``args.text_path`` is iterated as
    ``(filename, target)`` pairs. For every pair whose output image does not
    exist yet, a style sample is drawn from the folder dataset, the target
    text is rendered by the writer and the image is stored at
    ``args.output/filename``. Files that already exist are skipped.

    Args:
        args: Parsed run configuration (attribute-style namespace). Reads
            ``dataset_path``, ``charset_file``, ``checkpoint``, ``text_path``,
            ``output`` and ``device``. ``args.num_writers`` is overwritten,
            and ``args.alphabet`` is overwritten when ``charset_file`` is set.
    """
    dataset = FolderDataset(args.dataset_path)
    # NOTE(review): hard-coded writer count; presumably it has to match the
    # checkpoint being loaded -- confirm before using other checkpoints.
    args.num_writers = 339

    if args.charset_file:
        # Close the charset file deterministically instead of leaking the handle.
        with open(args.charset_file, 'rb') as charset_fp:
            charset = msgpack.load(charset_fp, use_list=False, strict_map_key=False)
        args.alphabet = "".join(charset['char2idx'].keys())

    writer = Writer(args.checkpoint, args, only_generator=True)

    with open(args.text_path, 'rb') as text_fp:
        lines = msgpack.load(text_fp, use_list=False)

    print(f"Generating {len(lines)} to {args.output}")

    # cv2.imwrite does not create missing directories; without this every
    # write into a fresh output folder would fail.
    if args.output:
        os.makedirs(args.output, exist_ok=True)

    for filename, target in lines:
        out_path = os.path.join(args.output, filename)
        if os.path.exists(out_path):
            # Already generated; skip it.
            continue

        # Add a leading dimension to the sampled style image and move it to the target device.
        style = torch.unsqueeze(dataset.sample_style()['simg'], dim=0).to(args.device)
        fake = writer.create_fake_sentence(style, target, at_once=True)

        # imwrite reports failure through its return value, so surface it.
        if not cv2.imwrite(out_path, fake):
            print(f"Warning: could not write {out_path}")

    print("Done")