Spaces:
Runtime error
Runtime error
Commit
·
08d9656
1
Parent(s):
e95cc03
fix app.py
Browse filescreate upload tool
- app.py +1 -2
- train.py +2 -1
- upload_to_HF.py +56 -0
app.py
CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
|
|
4 |
from pathlib import Path
|
5 |
from denoisers.SpectralGating import SpectralGating
|
6 |
|
7 |
-
|
8 |
|
9 |
def denoising_transform(audio):
|
10 |
src_path = Path(__file__).parent.resolve() / Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
|
@@ -32,6 +32,5 @@ demo = gr.Interface(
|
|
32 |
)
|
33 |
|
34 |
if __name__ == "__main__":
|
35 |
-
model = SpectralGating()
|
36 |
demo.launch()
|
37 |
|
|
|
4 |
from pathlib import Path
|
5 |
from denoisers.SpectralGating import SpectralGating
|
6 |
|
7 |
+
model = SpectralGating()
|
8 |
|
9 |
def denoising_transform(audio):
|
10 |
src_path = Path(__file__).parent.resolve() / Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
|
|
|
32 |
)
|
33 |
|
34 |
if __name__ == "__main__":
|
|
|
35 |
demo.launch()
|
36 |
|
train.py
CHANGED
@@ -34,7 +34,8 @@ def init_wandb(cfg):
|
|
34 |
def train(cfg: DictConfig):
|
35 |
device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
|
36 |
init_wandb(cfg)
|
37 |
-
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name
|
|
|
38 |
metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
|
39 |
|
40 |
model = get_model(cfg['model']).to(device)
|
|
|
34 |
def train(cfg: DictConfig):
|
35 |
device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
|
36 |
init_wandb(cfg)
|
37 |
+
checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name,
|
38 |
+
decreasing=False)
|
39 |
metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
|
40 |
|
41 |
model = get_model(cfg['model']).to(device)
|
upload_to_HF.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import wandb
|
3 |
+
from huggingface_hub import HfApi
|
4 |
+
from pathlib import Path
|
5 |
+
import huggingface_hub
|
6 |
+
import ssl
|
7 |
+
import os
|
8 |
+
os.environ['CURL_CA_BUNDLE'] = ''
|
9 |
+
|
10 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
11 |
+
|
12 |
+
class Uploader:
|
13 |
+
def __init__(self, entity, project, run_name, repo_id, username):
|
14 |
+
self.entity = entity
|
15 |
+
self.project = project
|
16 |
+
self.run_name = run_name
|
17 |
+
self.hf_api = HfApi()
|
18 |
+
self.wandb_api = wandb.Api()
|
19 |
+
self.repo_id = repo_id
|
20 |
+
self.username = username
|
21 |
+
huggingface_hub.login(os.environ.get('HUGGINGFACE_TOKEN'))
|
22 |
+
|
23 |
+
def get_model_from_wandb_run(self):
|
24 |
+
runs = self.wandb_api.runs(f"{self.entity}/{self.project}",
|
25 |
+
# order='+summary_metrics.train_pesq'
|
26 |
+
)
|
27 |
+
run = [run for run in runs if run.name == self.run_name][0]
|
28 |
+
artifacts = run.logged_artifacts()
|
29 |
+
best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
|
30 |
+
artifact_dir = best_model.download()
|
31 |
+
model_path = list(Path(artifact_dir).glob("*.pt"))[0].absolute().as_posix()
|
32 |
+
print(f"Model validation score = {best_model.metadata['Validation score']}")
|
33 |
+
return model_path
|
34 |
+
|
35 |
+
def upload_to_HF(self):
|
36 |
+
model_path = self.get_model_from_wandb_run()
|
37 |
+
self.hf_api.upload_file(
|
38 |
+
path_or_fileobj=model_path,
|
39 |
+
path_in_repo=Path(model_path).name,
|
40 |
+
repo_id=f'{self.username}/{self.repo_id}',
|
41 |
+
)
|
42 |
+
|
43 |
+
def create_repo(self):
|
44 |
+
self.hf_api.create_repo(repo_id=self.repo_id, exist_ok=True)
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
if __name__ == '__main__':
|
49 |
+
uploader = Uploader(entity='borisovmaksim',
|
50 |
+
project='denoising',
|
51 |
+
run_name='wav_normalization',
|
52 |
+
repo_id='demucs',
|
53 |
+
username='BorisovMaksim')
|
54 |
+
uploader.create_repo()
|
55 |
+
uploader.upload_to_HF()
|
56 |
+
|