BorisovMaksim commited on
Commit
08d9656
·
1 Parent(s): e95cc03

fix app.py

Browse files

create upload tool

Files changed (3) hide show
  1. app.py +1 -2
  2. train.py +2 -1
  3. upload_to_HF.py +56 -0
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
4
  from pathlib import Path
5
  from denoisers.SpectralGating import SpectralGating
6
 
7
-
8
 
9
  def denoising_transform(audio):
10
  src_path = Path(__file__).parent.resolve() / Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
@@ -32,6 +32,5 @@ demo = gr.Interface(
32
  )
33
 
34
  if __name__ == "__main__":
35
- model = SpectralGating()
36
  demo.launch()
37
 
 
4
  from pathlib import Path
5
  from denoisers.SpectralGating import SpectralGating
6
 
7
+ model = SpectralGating()
8
 
9
  def denoising_transform(audio):
10
  src_path = Path(__file__).parent.resolve() / Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
 
32
  )
33
 
34
  if __name__ == "__main__":
 
35
  demo.launch()
36
 
train.py CHANGED
@@ -34,7 +34,8 @@ def init_wandb(cfg):
34
  def train(cfg: DictConfig):
35
  device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
36
  init_wandb(cfg)
37
- checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
 
38
  metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
39
 
40
  model = get_model(cfg['model']).to(device)
 
34
  def train(cfg: DictConfig):
35
  device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')
36
  init_wandb(cfg)
37
+ checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name,
38
+ decreasing=False)
39
  metrics = Metrics(source_rate=cfg['dataloader']['sample_rate']).to(device)
40
 
41
  model = get_model(cfg['model']).to(device)
upload_to_HF.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import wandb
3
+ from huggingface_hub import HfApi
4
+ from pathlib import Path
5
+ import huggingface_hub
6
+ import ssl
7
+ import os
8
+ os.environ['CURL_CA_BUNDLE'] = ''
9
+
10
+ ssl._create_default_https_context = ssl._create_unverified_context
11
+
12
+ class Uploader:
13
+ def __init__(self, entity, project, run_name, repo_id, username):
14
+ self.entity = entity
15
+ self.project = project
16
+ self.run_name = run_name
17
+ self.hf_api = HfApi()
18
+ self.wandb_api = wandb.Api()
19
+ self.repo_id = repo_id
20
+ self.username = username
21
+ huggingface_hub.login(os.environ.get('HUGGINGFACE_TOKEN'))
22
+
23
+ def get_model_from_wandb_run(self):
24
+ runs = self.wandb_api.runs(f"{self.entity}/{self.project}",
25
+ # order='+summary_metrics.train_pesq'
26
+ )
27
+ run = [run for run in runs if run.name == self.run_name][0]
28
+ artifacts = run.logged_artifacts()
29
+ best_model = [artifact for artifact in artifacts if artifact.type == 'model'][0]
30
+ artifact_dir = best_model.download()
31
+ model_path = list(Path(artifact_dir).glob("*.pt"))[0].absolute().as_posix()
32
+ print(f"Model validation score = {best_model.metadata['Validation score']}")
33
+ return model_path
34
+
35
+ def upload_to_HF(self):
36
+ model_path = self.get_model_from_wandb_run()
37
+ self.hf_api.upload_file(
38
+ path_or_fileobj=model_path,
39
+ path_in_repo=Path(model_path).name,
40
+ repo_id=f'{self.username}/{self.repo_id}',
41
+ )
42
+
43
+ def create_repo(self):
44
+ self.hf_api.create_repo(repo_id=self.repo_id, exist_ok=True)
45
+
46
+
47
+
48
+ if __name__ == '__main__':
49
+ uploader = Uploader(entity='borisovmaksim',
50
+ project='denoising',
51
+ run_name='wav_normalization',
52
+ repo_id='demucs',
53
+ username='BorisovMaksim')
54
+ uploader.create_repo()
55
+ uploader.upload_to_HF()
56
+