PyTorch
ssl-aasist
custom_code
ash56 commited on
Commit
fdc0cf4
·
verified ·
1 Parent(s): edf183a

Update model_hf.py

Browse files
Files changed (1) hide show
  1. model_hf.py +6 -5
model_hf.py CHANGED
@@ -6,9 +6,9 @@ import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  from torch import Tensor
9
- import fairseq
10
  from .config_ssl import SSLConfig
11
  from huggingface_hub import hf_hub_download
 
12
 
13
  ___author__ = "Hemlata Tak"
14
  __email__ = "[email protected]"
@@ -21,10 +21,11 @@ __email__ = "[email protected]"
21
  class SSLModel(nn.Module):
22
  def __init__(self,device):
23
  super(SSLModel, self).__init__()
24
- repo_id = 'ash56/ssl-aasist'
25
- fname = 'xlsr2_300m.pt'
26
- cp_path = hf_hub_download(repo_id=repo_id, filename=fname) # Change the pre-trained XLSR model path.
27
- model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
 
28
  self.model = model[0]
29
  self.model_device=device
30
  self.out_dim = 1024
 
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  from torch import Tensor
 
9
  from .config_ssl import SSLConfig
10
  from huggingface_hub import hf_hub_download
11
+ from transformers import Wav2Vec2ForPreTraining
12
 
13
  ___author__ = "Hemlata Tak"
14
  __email__ = "[email protected]"
 
21
  class SSLModel(nn.Module):
22
  def __init__(self,device):
23
  super(SSLModel, self).__init__()
24
+ # eliminate fairseq dependency
25
+ repo_id = "facebook/wav2vec2-xlsr-300m"
26
+ model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-xlsr-300m")
27
+ # cp_path = hf_hub_download(repo_id=repo_id, filename=fname) # Change the pre-trained XLSR model path.
28
+ # model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
29
  self.model = model[0]
30
  self.model_device=device
31
  self.out_dim = 1024