import torch
from transformers import PreTrainedModel

from .config import AwesomeConfig


class AwesomeModel(PreTrainedModel):
    """Minimal Transformers model made of a single square linear layer.

    The layer maps features of size ``config.hidden_size`` to features of the
    same size. It is applied to the last dimension of the input, following
    ``torch.nn.Linear`` semantics.
    """

    # Lets Transformers build the right config class in ``from_pretrained``.
    config_class = AwesomeConfig
    # NOTE(review): this model defines no submodule named ``base``; confirm
    # the prefix matches how head models wrapping this class name it.
    base_model_prefix = "base"

    def __init__(self, config: AwesomeConfig) -> None:
        """Build the model from *config*.

        Args:
            config: Model configuration. Its ``hidden_size`` attribute sets
                both the input and output width of the linear layer.
        """
        super().__init__(config)
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
        # Transformers asks every model to call this at the end of
        # ``__init__``. It runs weight initialization (via ``_init_weights``)
        # and the library's final setup steps.
        self.post_init()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to *x*.

        Args:
            x: Input tensor whose last dimension is ``config.hidden_size``.

        Returns:
            A tensor with the same leading dimensions as *x* and a last
            dimension of ``config.hidden_size``.
        """
        return self.linear(x)