Update components/semantic_extractor/ssl_model.py
Browse files
components/semantic_extractor/ssl_model.py
CHANGED
|
@@ -47,6 +47,7 @@ def get_ssl_model(ckpt_path, km_path, device='cuda', type='xlsr'):
|
|
| 47 |
model = model.requires_grad_(False)
|
| 48 |
else:
|
| 49 |
raise NotImplementedError
|
|
|
|
| 50 |
km_model = ApplyKmeans(km_path, device)
|
| 51 |
return model, km_model
|
| 52 |
|
|
|
|
| 47 |
model = model.requires_grad_(False)
|
| 48 |
else:
|
| 49 |
raise NotImplementedError
|
| 50 |
+
km_path = hf_hub_download(repo_id="yaoxunji/gense", filename="wavlm_km.mdl")
|
| 51 |
km_model = ApplyKmeans(km_path, device)
|
| 52 |
return model, km_model
|
| 53 |
|