PyTorch
ssl-aasist
custom_code
ash56 commited on
Commit
edf183a
·
verified ·
1 Parent(s): cfe653a

Update model_hf.py

Browse files
Files changed (1) hide show
  1. model_hf.py +4 -2
model_hf.py CHANGED
@@ -8,6 +8,7 @@ import torch.nn.functional as F
8
  from torch import Tensor
9
  import fairseq
10
  from .config_ssl import SSLConfig
 
11
 
12
  ___author__ = "Hemlata Tak"
13
  __email__ = "[email protected]"
@@ -20,8 +21,9 @@ __email__ = "[email protected]"
20
  class SSLModel(nn.Module):
21
  def __init__(self,device):
22
  super(SSLModel, self).__init__()
23
-
24
- cp_path = './xlsr2_300m.pt' # Change the pre-trained XLSR model path.
 
25
  model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
26
  self.model = model[0]
27
  self.model_device=device
 
8
  from torch import Tensor
9
  import fairseq
10
  from .config_ssl import SSLConfig
11
+ from huggingface_hub import hf_hub_download
12
 
13
  ___author__ = "Hemlata Tak"
14
  __email__ = "[email protected]"
 
21
  class SSLModel(nn.Module):
22
  def __init__(self,device):
23
  super(SSLModel, self).__init__()
24
+ repo_id = 'ash56/ssl-aasist'
25
+ fname = 'xlsr2_300m.pt'
26
+ cp_path = hf_hub_download(repo_id=repo_id, filename=fname) # Change the pre-trained XLSR model path.
27
  model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
28
  self.model = model[0]
29
  self.model_device=device