"""Thin ``nn.Module`` wrapper around DeFTAN2 with a ``from_pretrained`` loader."""

import os

import torch
import torch.nn as nn
from transformers import AutoConfig

from DeFTAN2 import DeFTAN2


class DeFTAN2Model(nn.Module):
    """Wrap a ``DeFTAN2`` network built from a config object.

    The wrapped network is registered as the submodule ``self.model``, so the
    keys of this wrapper's state dict carry a ``model.`` prefix.
    """

    # File name of the serialized state dict inside a model directory.
    WEIGHTS_NAME = "deftan2.bin"

    def __init__(self, config):
        """Build the underlying DeFTAN2 network.

        Args:
            config: Configuration object passed unchanged to ``DeFTAN2``.
        """
        super().__init__()
        self.model = DeFTAN2(config)

    def forward(self, x):
        """Run the wrapped DeFTAN2 network on ``x`` and return its output."""
        # NOTE(review): the expected shape/dtype of ``x`` is defined by
        # DeFTAN2, which is not visible here -- TODO document once confirmed.
        return self.model(x)

    @classmethod
    def from_pretrained(cls, model_path, *, map_location="cpu", strict=True):
        """Create a model from a directory holding ``config.json`` and weights.

        Args:
            model_path: Path to a local directory containing ``config.json``
                and the ``deftan2.bin`` state dict. The weights are read from
                the local filesystem only, so a remote hub id is not enough
                even if ``AutoConfig`` can resolve it.
            map_location: Forwarded to ``torch.load``; defaults to ``"cpu"``.
            strict: Forwarded to ``load_state_dict``; when True, missing or
                unexpected keys raise ``RuntimeError``.

        Returns:
            A ``DeFTAN2Model`` instance with the loaded weights.
        """
        # Load the model settings from config.json.
        # NOTE(review): AutoConfig picks the config class from the
        # ``model_type`` field in config.json; confirm that DeFTAN2's
        # model_type is known to transformers, or AutoConfig will reject it.
        config = AutoConfig.from_pretrained(model_path)
        model = cls(config)

        weights_path = os.path.join(model_path, cls.WEIGHTS_NAME)
        # SECURITY: torch.load unpickles the file and can execute arbitrary
        # code -- only load checkpoints from trusted sources. Consider passing
        # ``weights_only=True`` once the minimum supported PyTorch allows it.
        state_dict = torch.load(weights_path, map_location=map_location)
        # NOTE(review): keys must match this wrapper (``model.``-prefixed); a
        # checkpoint saved from a bare DeFTAN2 would need its keys remapped.
        model.load_state_dict(state_dict, strict=strict)
        return model