# vatrpp / configuration_vatrpp.py
# vittoriopippi
# Error with model_type
# 0f0df26
from transformers import PretrainedConfig
class VATrPPConfig(PretrainedConfig):
    """Hugging Face ``PretrainedConfig`` subclass registered as ``model_type = "vatrpp"``.

    Every named constructor argument is stored unchanged as an instance
    attribute of the same name. No validation or type conversion is done.
    Extra keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    NOTE(review): most meanings below are inferred from parameter names and
    defaults only. They are not verified against the code that consumes this
    config, so confirm them before relying on them.

    Args:
        feat_model_path: Path to a pretrained feature-model checkpoint
            (the default points at ``files/resnet_18_pretrained.pth``).
        label_encoder: Label-encoder identifier (default ``'default'``).
        save_model_path: Where models are saved (presumably a directory).
        dataset: Dataset name (default ``'IAM'``).
        english_words_path: Path to a word-list text file.
        wandb: Flag; presumably enables Weights & Biases logging.
        no_writer_loss: Flag; presumably disables the writer loss term.
        writer_loss_weight: Presumably the weight of the writer loss term.
        no_ocr_loss: Flag; presumably disables the OCR loss term.
        img_height: Image height setting (units not visible here).
        resolution: Resolution setting (semantics not visible here).
        batch_size: Batch size.
        num_examples: Number of examples (presumably style samples per writer).
        num_writers: Number of writers.
        alphabet: Default character set.
        special_alphabet: Additional characters (Greek letters by default).
        g_lr: Learning rate, presumably for the generator.
        d_lr: Learning rate, presumably for the discriminator.
        w_lr: Learning rate, presumably for the writer classifier.
        ocr_lr: Learning rate, presumably for the OCR model.
        epochs: Number of training epochs.
        num_workers: Data-loader worker count (presumably).
        seed: Random seed.
        num_words: Number of words (semantics not visible here).
        is_cycle: Boolean switch (semantics not visible here).
        add_noise: Boolean switch; presumably injects noise.
        save_model: Presumably a checkpoint interval.
        save_model_history: Presumably a checkpoint-history interval.
        tag: Run tag (default ``'debug'``).
        device: Device string (default ``'cuda'``).
        query_input: Query-input identifier (default ``'unifont'``).
        corpus: Text-corpus identifier (default ``"standard"``).
        text_augment_strength: Text-augmentation strength.
        text_aug_type: Text-augmentation type (default ``"proportional"``).
        file_suffix: Optional suffix; ``None`` by default (type not visible).
        augment_ocr: Flag; presumably enables OCR augmentation.
        d_crop_size: Optional crop size; ``None`` by default (type not visible).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "vatrpp"

    def __init__(
        self,
        feat_model_path: str = 'files/resnet_18_pretrained.pth',
        label_encoder: str = 'default',
        save_model_path: str = 'saved_models',
        dataset: str = 'IAM',
        english_words_path: str = 'files/english_words.txt',
        wandb: bool = False,
        no_writer_loss: bool = False,
        writer_loss_weight: float = 1.0,
        no_ocr_loss: bool = False,
        img_height: int = 32,
        resolution: int = 16,
        batch_size: int = 8,
        num_examples: int = 15,
        num_writers: int = 339,
        alphabet: str = 'Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%',
        special_alphabet: str = 'ΑαΒβΓγΔδΕεΖζΗηΘθΙιΚκΛλΜμΝνΞξΟοΠπΡρΣσςΤτΥυΦφΧχΨψΩω',
        g_lr: float = 5e-5,
        d_lr: float = 1e-5,
        w_lr: float = 5e-5,
        ocr_lr: float = 5e-5,
        epochs: int = 100_000,
        num_workers: int = 0,
        seed: int = 742,
        num_words: int = 3,
        is_cycle: bool = False,
        add_noise: bool = False,
        save_model: int = 5,
        save_model_history: int = 500,
        tag: str = 'debug',
        device: str = 'cuda',
        query_input: str = 'unifont',
        corpus: str = "standard",
        text_augment_strength: float = 0.0,
        text_aug_type: str = "proportional",
        file_suffix=None,
        augment_ocr: bool = False,
        d_crop_size=None,
        **kwargs,
    ):
        # The base class handles its own keyword arguments first; the model
        # fields assigned below are applied afterwards.
        super().__init__(**kwargs)

        # This mapping's insertion order mirrors the signature, so the
        # attributes are set in the same order as the parameters are declared.
        # That keeps the instance __dict__ ordering stable.
        settings = {
            "feat_model_path": feat_model_path,
            "label_encoder": label_encoder,
            "save_model_path": save_model_path,
            "dataset": dataset,
            "english_words_path": english_words_path,
            "wandb": wandb,
            "no_writer_loss": no_writer_loss,
            "writer_loss_weight": writer_loss_weight,
            "no_ocr_loss": no_ocr_loss,
            "img_height": img_height,
            "resolution": resolution,
            "batch_size": batch_size,
            "num_examples": num_examples,
            "num_writers": num_writers,
            "alphabet": alphabet,
            "special_alphabet": special_alphabet,
            "g_lr": g_lr,
            "d_lr": d_lr,
            "w_lr": w_lr,
            "ocr_lr": ocr_lr,
            "epochs": epochs,
            "num_workers": num_workers,
            "seed": seed,
            "num_words": num_words,
            "is_cycle": is_cycle,
            "add_noise": add_noise,
            "save_model": save_model,
            "save_model_history": save_model_history,
            "tag": tag,
            "device": device,
            "query_input": query_input,
            "corpus": corpus,
            "text_augment_strength": text_augment_strength,
            "text_aug_type": text_aug_type,
            "file_suffix": file_suffix,
            "augment_ocr": augment_ocr,
            "d_crop_size": d_crop_size,
        }
        # setattr goes through the same __setattr__ as `self.name = value`.
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)