from transformers import PretrainedConfig, AutoConfig


class VATrPPConfig(PretrainedConfig):
    """Configuration for the VATr++ styled handwritten text generation model.

    Stores the dataset, alphabet, and training hyper-parameters (generator,
    discriminator, writer-classifier, and OCR learning rates) used by the
    VATr++ training and inference code.
    """

    model_type = "vatrpp"

    def __init__(self,
                 feat_model_path='files/resnet_18_pretrained.pth',
                 label_encoder='default',
                 save_model_path='saved_models',
                 dataset='IAM',
                 english_words_path='files/english_words.txt',
                 wandb=False,
                 no_writer_loss=False,
                 writer_loss_weight=1.0,
                 no_ocr_loss=False,
                 img_height=32,
                 resolution=16,
                 batch_size=8,
                 num_examples=15,
                 num_writers=339,
                 alphabet='Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%',
                 special_alphabet='ΑαΒβΓγΔδΕεΖζΗηΘθΙιΚκΛλΜμΝνΞξΟοΠπΡρΣσςΤτΥυΦφΧχΨψΩω',
                 g_lr=0.00005,
                 d_lr=0.00001,
                 w_lr=0.00005,
                 ocr_lr=0.00005,
                 epochs=100000,
                 num_workers=0,
                 seed=742,
                 num_words=3,
                 is_cycle=False,
                 add_noise=False,
                 save_model=5,
                 save_model_history=500,
                 tag='debug',
                 device='cuda',
                 query_input='unifont',
                 corpus="standard",
                 text_augment_strength=0.0,
                 text_aug_type="proportional",
                 file_suffix=None,
                 augment_ocr=False,
                 d_crop_size=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.feat_model_path = feat_model_path
        self.label_encoder = label_encoder
        self.save_model_path = save_model_path
        self.dataset = dataset
        self.english_words_path = english_words_path
        self.wandb = wandb
        self.no_writer_loss = no_writer_loss
        self.writer_loss_weight = writer_loss_weight
        self.no_ocr_loss = no_ocr_loss
        self.img_height = img_height
        self.resolution = resolution
        self.batch_size = batch_size
        self.num_examples = num_examples
        self.num_writers = num_writers
        self.alphabet = alphabet
        self.special_alphabet = special_alphabet
        self.g_lr = g_lr
        self.d_lr = d_lr
        self.w_lr = w_lr
        self.ocr_lr = ocr_lr
        self.epochs = epochs
        self.num_workers = num_workers
        self.seed = seed
        self.num_words = num_words
        self.is_cycle = is_cycle
        self.add_noise = add_noise
        self.save_model = save_model
        self.save_model_history = save_model_history
        self.tag = tag
        self.device = device
        self.query_input = query_input
        self.corpus = corpus
        self.text_augment_strength = text_augment_strength
        self.text_aug_type = text_aug_type
        self.file_suffix = file_suffix
        self.augment_ocr = augment_ocr
        self.d_crop_size = d_crop_size


# AutoConfig.register expects both the model type string and the config class,
# so the class is registered after its definition rather than via a decorator.
AutoConfig.register("vatrpp", VATrPPConfig)
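
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the training code): build
    # a config with a couple of overrides, save it, and reload it through
    # AutoConfig to check that the "vatrpp" registration resolves back to
    # VATrPPConfig. The directory name "vatrpp_config_demo" is arbitrary.
    config = VATrPPConfig(batch_size=16, device="cpu")
    config.save_pretrained("vatrpp_config_demo")   # writes config.json
    reloaded = AutoConfig.from_pretrained("vatrpp_config_demo")
    assert isinstance(reloaded, VATrPPConfig)
    print(reloaded.model_type, reloaded.batch_size)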