donghoney0416 commited on
Commit
9fd08f7
·
verified ·
1 Parent(s): 795921a

Upload DeFTAN2Model.py

Browse files
Files changed (1) hide show
  1. DeFTAN2Model.py +20 -0
DeFTAN2Model.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoConfig
4
+ from DeFTAN2 import DeFTAN2
5
+
6
+ class DeFTAN2Model(nn.Module):
7
+ def __init__(self, config):
8
+ super(DeFTAN2Model, self).__init__()
9
+ self.model = DeFTAN2(config)
10
+
11
+ def forward(self, x):
12
+ return self.model(x)
13
+
14
+ @classmethod
15
+ def from_pretrained(cls, model_path):
16
+ config = AutoConfig.from_pretrained(model_path) # config.json에서 설정 로드
17
+ model = cls(config)
18
+ state_dict = torch.load(f"{model_path}/deftan2.bin", map_location="cpu")
19
+ model.load_state_dict(state_dict)
20
+ return model