|
from transformers import PretrainedConfig |
|
|
|
class VATrPPConfig(PretrainedConfig):
    """Hugging Face configuration object registered under ``model_type = "vatrpp"``.

    Every constructor argument is stored, unchanged, as an instance attribute
    of the same name after the base ``PretrainedConfig.__init__`` has been
    called with ``**kwargs``. No validation or derived values are computed here.

    Rough grouping of the fields. The meanings are inferred from the names
    only; NOTE(review): confirm them against the code that reads these fields.

    * Paths: ``feat_model_path``, ``save_model_path``, ``english_words_path``.
    * Data / text: ``label_encoder``, ``dataset``, ``img_height``,
      ``resolution``, ``num_examples``, ``num_writers``, ``alphabet``,
      ``special_alphabet``, ``num_words``, ``query_input``, ``corpus``,
      ``text_augment_strength``, ``text_aug_type``, ``file_suffix``.
    * Loss switches: ``no_writer_loss``, ``writer_loss_weight``,
      ``no_ocr_loss``, ``is_cycle``, ``add_noise``, ``augment_ocr``,
      ``d_crop_size``.
    * Optimisation: ``g_lr``, ``d_lr``, ``w_lr``, ``ocr_lr`` (presumably
      generator / discriminator / writer / OCR learning rates), ``batch_size``,
      ``epochs``, ``num_workers``, ``seed``.
    * Run bookkeeping: ``wandb``, ``save_model``, ``save_model_history``,
      ``tag``, ``device``.
    """

    model_type = "vatrpp"

    def __init__(self,
                 feat_model_path='files/resnet_18_pretrained.pth',
                 label_encoder='default',
                 save_model_path='saved_models',
                 dataset='IAM',
                 english_words_path='files/english_words.txt',
                 wandb=False,
                 no_writer_loss=False,
                 writer_loss_weight=1.0,
                 no_ocr_loss=False,
                 img_height=32,
                 resolution=16,
                 batch_size=8,
                 num_examples=15,
                 num_writers=339,
                 alphabet='Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%',
                 special_alphabet='ΑαΒβΓγΔδΕεΖζΗηΘθΙιΚκΛλΜμΝνΞξΟοΠπΡρΣσςΤτΥυΦφΧχΨψΩω',
                 g_lr=0.00005,
                 d_lr=0.00001,
                 w_lr=0.00005,
                 ocr_lr=0.00005,
                 epochs=100000,
                 num_workers=0,
                 seed=742,
                 num_words=3,
                 is_cycle=False,
                 add_noise=False,
                 save_model=5,
                 save_model_history=500,
                 tag='debug',
                 device='cuda',
                 query_input='unifont',
                 corpus="standard",
                 text_augment_strength=0.0,
                 text_aug_type="proportional",
                 file_suffix=None,
                 augment_ocr=False,
                 d_crop_size=None,
                 **kwargs):
        # Hand any extra keyword arguments to the base config before our own
        # fields are attached.
        super().__init__(**kwargs)

        # Attribute name -> value, listed in the same order as the signature so
        # the attributes are set (and land in ``__dict__``) in that order.
        settings = {
            'feat_model_path': feat_model_path,
            'label_encoder': label_encoder,
            'save_model_path': save_model_path,
            'dataset': dataset,
            'english_words_path': english_words_path,
            'wandb': wandb,
            'no_writer_loss': no_writer_loss,
            'writer_loss_weight': writer_loss_weight,
            'no_ocr_loss': no_ocr_loss,
            'img_height': img_height,
            'resolution': resolution,
            'batch_size': batch_size,
            'num_examples': num_examples,
            'num_writers': num_writers,
            'alphabet': alphabet,
            'special_alphabet': special_alphabet,
            'g_lr': g_lr,
            'd_lr': d_lr,
            'w_lr': w_lr,
            'ocr_lr': ocr_lr,
            'epochs': epochs,
            'num_workers': num_workers,
            'seed': seed,
            'num_words': num_words,
            'is_cycle': is_cycle,
            'add_noise': add_noise,
            'save_model': save_model,
            'save_model_history': save_model_history,
            'tag': tag,
            'device': device,
            'query_input': query_input,
            'corpus': corpus,
            'text_augment_strength': text_augment_strength,
            'text_aug_type': text_aug_type,
            'file_suffix': file_suffix,
            'augment_ocr': augment_ocr,
            'd_crop_size': d_crop_size,
        }
        # setattr goes through the same __setattr__ as `self.x = value`.
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)
|