File size: 611 Bytes
bcb1848 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
"""
모델들 불러오는 모듈
"""
import torch
# from .load_model import KCSN
# from .arguments import get_train_args
# args = get_train_args()
def load_ner(path='model/NER.pth', map_location=None):
    """
    Load the NER model from a checkpoint file.

    The checkpoint is expected to be a dict-like object holding the pickled
    model object under ``'model'`` and its weights under
    ``'model_state_dict'``. The saved weights are re-applied to the
    unpickled model before it is returned.

    Args:
        path: Path to the checkpoint file.
        map_location: Forwarded to ``torch.load``. For example, pass ``'cpu'``
            to load a checkpoint that was saved on GPU onto a CPU-only
            machine. ``None`` (the default) keeps torch's default behaviour.

    Returns:
        A ``(model, checkpoint)`` tuple: the restored model and the full
        loaded checkpoint, so callers can read any extra saved entries.

    Warning:
        The checkpoint stores a whole pickled model object, so it is loaded
        with ``weights_only=False``, which can execute arbitrary code during
        unpickling. Only load checkpoint files from trusted sources.
    """
    import inspect  # local import keeps the module's top-level imports unchanged

    load_kwargs = {'map_location': map_location}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # the full model object stored under 'model'. Opt out explicitly when the
    # installed torch supports the flag (it does not exist before 1.13).
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(path, **load_kwargs)

    model = checkpoint['model']
    # Re-apply the saved weights so the returned model matches the state dict.
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, checkpoint
# def load_fs(path = 'model/FS.pth'):
# """
# Find Speaker 모델
# """
# model = KCSN(args)
# checkpoint = torch.load(path)
# model.load_state_dict(checkpoint['model_state_dict'])
# return model, checkpoint
|