DeFTAN-II / DeFTAN2Model.py
donghoney0416's picture
Upload DeFTAN2Model.py
9fd08f7 verified
raw
history blame
646 Bytes
import os

import torch
import torch.nn as nn
from transformers import AutoConfig

from DeFTAN2 import DeFTAN2
class DeFTAN2Model(nn.Module):
    """Thin ``nn.Module`` wrapper around ``DeFTAN2`` with a ``from_pretrained`` loader.

    The wrapped network is registered as the submodule ``self.model``, so every
    key in this module's ``state_dict`` carries a ``model.`` prefix; a checkpoint
    passed to ``from_pretrained`` must use the same key names.
    """

    def __init__(self, config):
        """Build the underlying DeFTAN2 network.

        Args:
            config: Configuration object forwarded unchanged to ``DeFTAN2``
                (``from_pretrained`` obtains it via ``AutoConfig``).
        """
        super().__init__()
        self.model = DeFTAN2(config)

    def forward(self, x):
        """Run the wrapped DeFTAN2 network on ``x`` and return its output unchanged."""
        # NOTE(review): the expected shape/dtype of ``x`` is defined by DeFTAN2,
        # which is not visible here.
        return self.model(x)

    @staticmethod
    def _load_state_dict(model_path, weights_name="deftan2.bin", map_location="cpu"):
        """Read and return the state dict stored at ``<model_path>/<weights_name>``.

        Args:
            model_path: Local directory containing the weights file.
            weights_name: File name of the serialized state dict.
            map_location: Forwarded to ``torch.load`` to choose the target device.

        Returns:
            The deserialized state dict.
        """
        weights_path = os.path.join(model_path, weights_name)
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered checkpoint cannot execute arbitrary code on load.
        return torch.load(weights_path, map_location=map_location, weights_only=True)

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        *,
        weights_name="deftan2.bin",
        map_location="cpu",
        strict=True,
    ):
        """Build a model from its config and load weights from ``model_path``.

        Args:
            model_path: Directory holding the config read by
                ``AutoConfig.from_pretrained`` and the weights file.
            weights_name: File name of the saved state dict inside ``model_path``.
            map_location: Forwarded to ``torch.load``; the default ``"cpu"``
                lets a GPU-saved checkpoint load on a CPU-only machine.
            strict: Forwarded to ``load_state_dict``; when True, missing or
                unexpected keys raise ``RuntimeError``.

        Returns:
            A new instance of ``cls`` with the loaded weights. It is left in
            training mode (the ``nn.Module`` default); call ``.eval()`` for
            inference.
        """
        # NOTE(review): AutoConfig may also accept a Hub repo id, but the weights
        # are read from the local filesystem, so ``model_path`` must be a local
        # directory for loading to succeed.
        # Load the model settings from config.json.
        config = AutoConfig.from_pretrained(model_path)
        model = cls(config)
        state_dict = cls._load_state_dict(
            model_path, weights_name=weights_name, map_location=map_location
        )
        model.load_state_dict(state_dict, strict=strict)
        return model