|
import os

import torch
import torch.nn as nn
from transformers import AutoConfig

from DeFTAN2 import DeFTAN2
|
|
|
|
class DeFTAN2Model(nn.Module):
    """``nn.Module`` wrapper around ``DeFTAN2`` with a checkpoint loader.

    The wrapped network is stored as ``self.model``. As a result, the keys of
    a state dict loaded through :meth:`from_pretrained` must carry the
    ``model.`` prefix.
    """

    def __init__(self, config):
        """Build the wrapped DeFTAN2 network.

        Args:
            config: Configuration object passed unchanged to ``DeFTAN2``.
                :meth:`from_pretrained` supplies the result of
                ``AutoConfig.from_pretrained``.
        """
        super().__init__()
        self.model = DeFTAN2(config)

    def forward(self, x):
        """Run the wrapped network on *x* and return its output unchanged."""
        return self.model(x)

    @classmethod
    def from_pretrained(
        cls,
        model_path,
        *,
        map_location="cpu",
        weights_name="deftan2.bin",
        strict=True,
    ):
        """Construct the model from a config and load a saved state dict.

        Args:
            model_path: Directory passed to ``AutoConfig.from_pretrained``.
                It must also contain the weights file, which is read from the
                local filesystem.
            map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"``,
                matching the previous hard-coded behavior.
            weights_name: File name of the weights inside ``model_path``.
            strict: Forwarded to ``nn.Module.load_state_dict``. Defaults to
                ``True``, PyTorch's own default.

        Returns:
            A new instance with the checkpoint weights loaded. Its train/eval
            mode is left as set by construction.

        Raises:
            FileNotFoundError: If the weights file does not exist.
            RuntimeError: If ``strict`` is true and the state-dict keys do not
                match the model's parameters.
        """
        config = AutoConfig.from_pretrained(model_path)
        model = cls(config)

        weights_path = os.path.join(model_path, weights_name)
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources, or pass
        # weights_only=True on torch versions that support it.
        state_dict = torch.load(weights_path, map_location=map_location)
        model.load_state_dict(state_dict, strict=strict)
        return model