from transformers import PretrainedConfig


class VATrPPConfig(PretrainedConfig):
    """Configuration class for VATr++ models, storing model and training hyperparameters."""

    model_type = "vatrpp"

    def __init__(self,
                 feat_model_path='files/resnet_18_pretrained.pth',
                 label_encoder='default',
                 save_model_path='saved_models',
                 dataset='IAM',
                 english_words_path='files/english_words.txt',
                 wandb=False,
                 no_writer_loss=False,
                 writer_loss_weight=1.0,
                 no_ocr_loss=False,
                 img_height=32,
                 resolution=16,
                 batch_size=8,
                 num_examples=15,
                 num_writers=339,
                 alphabet='Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%',
                 special_alphabet='ΑαΒβΓγΔδΕεΖζΗηΘθΙιΚκΛλΜμΝνΞξΟοΠπΡρΣσςΤτΥυΦφΧχΨψΩω',
                 g_lr=0.00005,
                 d_lr=0.00001,
                 w_lr=0.00005,
                 ocr_lr=0.00005,
                 epochs=100000,
                 num_workers=0,
                 seed=742,
                 num_words=3,
                 is_cycle=False,
                 add_noise=False,
                 save_model=5,
                 save_model_history=500,
                 tag='debug',
                 device='cuda',
                 query_input='unifont',
                 corpus="standard",
                 text_augment_strength=0.0,
                 text_aug_type="proportional",
                 file_suffix=None,
                 augment_ocr=False,
                 d_crop_size=None,
                 **kwargs):
        super().__init__(**kwargs)

        # Paths, dataset, and logging options
        self.feat_model_path = feat_model_path
        self.label_encoder = label_encoder
        self.save_model_path = save_model_path
        self.dataset = dataset
        self.english_words_path = english_words_path
        self.wandb = wandb

        # Loss configuration
        self.no_writer_loss = no_writer_loss
        self.writer_loss_weight = writer_loss_weight
        self.no_ocr_loss = no_ocr_loss

        # Input shape, batching, and alphabets
        self.img_height = img_height
        self.resolution = resolution
        self.batch_size = batch_size
        self.num_examples = num_examples
        self.num_writers = num_writers
        self.alphabet = alphabet
        self.special_alphabet = special_alphabet

        # Learning rates
        self.g_lr = g_lr
        self.d_lr = d_lr
        self.w_lr = w_lr
        self.ocr_lr = ocr_lr

        # Training schedule and checkpointing
        self.epochs = epochs
        self.num_workers = num_workers
        self.seed = seed
        self.num_words = num_words
        self.is_cycle = is_cycle
        self.add_noise = add_noise
        self.save_model = save_model
        self.save_model_history = save_model_history
        self.tag = tag
        self.device = device

        # Text query and augmentation options
        self.query_input = query_input
        self.corpus = corpus
        self.text_augment_strength = text_augment_strength
        self.text_aug_type = text_aug_type
        self.file_suffix = file_suffix
        self.augment_ocr = augment_ocr
        self.d_crop_size = d_crop_size
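

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes only that
# the `transformers` library is installed and shows how the config round-trips
# through the standard save/load helpers inherited from PretrainedConfig. The
# directory name "vatrpp_config" is an arbitrary example path.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Instantiate with defaults, overriding a couple of hyperparameters.
    config = VATrPPConfig(batch_size=16, img_height=32)

    # PretrainedConfig provides JSON serialization out of the box.
    config.save_pretrained("vatrpp_config")
    reloaded = VATrPPConfig.from_pretrained("vatrpp_config")

    assert reloaded.batch_size == 16
    print(reloaded.model_type)  # -> "vatrpp"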