File size: 611 Bytes
bcb1848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""
Module for loading the trained models.
"""
import torch
# from .load_model import KCSN
# from .arguments import get_train_args


# args = get_train_args()

def load_ner(path='model/NER.pth', map_location=None):
    """
    Load the NER model from a training checkpoint.

    The checkpoint must be a mapping that holds the pickled model object under
    ``'model'`` and its weights under ``'model_state_dict'``. The weights are
    (re)applied to the unpickled model before it is returned.

    Args:
        path: Path to the checkpoint file written by ``torch.save``.
        map_location: Forwarded to ``torch.load``. Use ``'cpu'``, for example,
            to load a checkpoint saved on GPU onto a CPU-only machine.
            ``None`` (the default) keeps the previous behaviour.

    Returns:
        A ``(model, checkpoint)`` tuple: the restored model and the full
        loaded checkpoint, so callers can read any other saved entries.

    Raises:
        KeyError: if the checkpoint lacks ``'model'`` or ``'model_state_dict'``.
    """
    import inspect

    load_kwargs = {'map_location': map_location}
    # The checkpoint contains a whole pickled nn.Module. The weights_only=True
    # default of torch >= 2.6 refuses to unpickle it, so opt out explicitly
    # where the argument exists (torch >= 1.13). Older versions always
    # unpickle fully and would reject the unknown keyword.
    # SECURITY: full unpickling can execute arbitrary code; only load
    # checkpoints from trusted sources.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(path, **load_kwargs)

    model = checkpoint['model']
    model.load_state_dict(checkpoint['model_state_dict'])

    return model, checkpoint


# def load_fs(path = 'model/FS.pth'):
#     """
#     Find Speaker 모델
#     """
#     model = KCSN(args)
#     checkpoint = torch.load(path)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     return model, checkpoint