|
|
|
from transformers import PretrainedConfig, AutoConfig |
|
from typing import List |
|
|
|
|
|
class BioNextTaggerConfig(PretrainedConfig):
    """Configuration for the "crf-tagger" model type.

    Extends ``PretrainedConfig`` with tagger-specific hyperparameters. Any
    extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.
    """

    model_type = "crf-tagger"

    def __init__(
        self,
        augmentation: str = "unk",
        context_size: int = 64,
        percentage_tags: float = 0.2,
        p_augmentation: float = 0.5,
        crf_reduction: str = "mean",
        version: str = "0.1.1",
        **kwargs,
    ):
        """Store the tagger hyperparameters, then defer to the base class.

        Args:
            augmentation: name of the augmentation strategy (default "unk").
                Its interpretation lives outside this file -- presumably a
                token-replacement scheme; TODO confirm.
            context_size: stored as-is (default 64); its meaning is defined
                by the consuming model code, not here.
            percentage_tags: stored as-is (default 0.2); likely a fraction
                used by the augmentation -- verify against the caller.
            p_augmentation: stored as-is (default 0.5); presumably the
                probability of applying augmentation -- verify.
            crf_reduction: stored as-is (default "mean"); presumably the
                reduction mode for the CRF loss -- verify.
            version: version string recorded on the config (default "0.1.1").
            **kwargs: forwarded verbatim to ``PretrainedConfig.__init__``.
        """
        # Custom fields are set before the base initializer runs,
        # mirroring the original ordering.
        self.augmentation = augmentation
        self.context_size = context_size
        self.percentage_tags = percentage_tags
        self.p_augmentation = p_augmentation
        self.crf_reduction = crf_reduction
        self.version = version
        super().__init__(**kwargs)

    def get_backbonemodel_config(self) -> PretrainedConfig:
        """Load the config at ``self._name_or_path`` and overlay shared fields.

        The config is loaded with ``AutoConfig.from_pretrained``. For every
        key in that config's ``to_dict()`` output that is also an attribute
        of ``self``, the value from ``self`` is written onto the loaded
        config.

        Returns:
            The loaded config object, with the shared attributes overwritten.
        """
        backbone_config = AutoConfig.from_pretrained(self._name_or_path)

        # Only ``backbone_config`` is mutated below, so collecting the shared
        # keys up front gives the same result as checking them one by one.
        shared_keys = [
            key for key in backbone_config.to_dict() if hasattr(self, key)
        ]
        # NOTE(review): ``to_dict()`` typically reports ``model_type``, and
        # ``self`` always has one (class attribute "crf-tagger"), so it is
        # copied onto the backbone config as well -- confirm that is intended.
        for key in shared_keys:
            setattr(backbone_config, key, getattr(self, key))

        return backbone_config
|
|