PyTorch
ssl-aasist
custom_code
ash56 commited on
Commit
aa5a0cc
·
verified ·
1 Parent(s): fdc0cf4

Update model_hf.py

Browse files
Files changed (1) hide show
  1. model_hf.py +3 -2
model_hf.py CHANGED
@@ -22,8 +22,9 @@ class SSLModel(nn.Module):
22
  def __init__(self,device):
23
  super(SSLModel, self).__init__()
24
  # eliminate fairseq dependency
25
- repo_id = "facebook/wav2vec2-xlsr-300m"
26
- model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-xlsr-300m")
 
27
  # cp_path = hf_hub_download(repo_id=repo_id, filename=fname) # Change the pre-trained XLSR model path.
28
  # model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
29
  self.model = model[0]
 
22
  def __init__(self,device):
23
  super(SSLModel, self).__init__()
24
  # eliminate fairseq dependency
25
+ # facebook/wav2vec2-xls-r-300m
26
+ # repo_id = "facebook/wav2vec2-xlsr-300m"
27
+ model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-xls-r-300m")
28
  # cp_path = hf_hub_download(repo_id=repo_id, filename=fname) # Change the pre-trained XLSR model path.
29
  # model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
30
  self.model = model[0]