BorisovMaksim committed on
Commit
a6a74d4
·
1 Parent(s): d38178c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -23
app.py CHANGED
@@ -8,9 +8,11 @@ from denoisers.demucs import Demucs
8
  import torch
9
  import torchaudio
10
  import yaml
 
11
 
12
  import os
13
  os.environ['CURL_CA_BUNDLE'] = ''
 
14
 
15
 
16
  def denoising_transform(audio, model):
@@ -19,16 +21,16 @@ def denoising_transform(audio, model):
19
  src_path.parent.mkdir(exist_ok=True, parents=True)
20
  tgt_path.parent.mkdir(exist_ok=True, parents=True)
21
  (ffmpeg.input(audio)
22
- .output(src_path.as_posix(), acodec='pcm_s16le', ac=1, ar=22050)
23
  .run()
24
  )
25
- wav, rate = torchaudio.load(audio)
26
  reduced_noise = model.predict(wav)
27
  torchaudio.save(tgt_path, reduced_noise, rate)
28
- return tgt_path
29
 
30
 
31
- def run_app(model_filename, config_filename):
32
  model_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=model_filename)
33
  config_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=config_filename)
34
  with open(config_path, 'r') as f:
@@ -37,26 +39,96 @@ def run_app(model_filename, config_filename):
37
  checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
38
  model.load_state_dict(checkpoint['model_state_dict'])
39
 
40
- interface_demucs = gr.Interface(
41
- fn=lambda x: denoising_transform(x, model),
42
- inputs=gr.Audio(label="Source Audio", source="microphone", type='filepath'),
43
- outputs=gr.Audio(label="Demucs", type='filepath'),
44
- allow_flagging='never'
45
- )
46
- interface_spectral_gating = gr.Interface(
47
- fn=lambda x: denoising_transform(x, SpectralGating()),
48
- inputs=gr.Audio(label="Source Audio", source="microphone", type='filepath'),
49
- outputs=gr.Audio(label="Spectral Gating", type='filepath'),
50
- allow_flagging='never'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
- gr.Parallel(interface_demucs, interface_spectral_gating,
53
- title="Denoising",
54
- examples=[[path] for path in Path("testing/wavs/").glob("*.wav")]
55
- ).launch(server_name='0.0.0.0',
56
- server_port=7860)
57
 
58
 
59
  if __name__ == "__main__":
60
- model_filename = "paper_replica_10_epoch/Demucs_replicate_paper_continue_epoch45.pt"
61
- config_filename = "paper_replica_10_epoch/config.yaml"
62
- run_app(model_filename, config_filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import torch
9
  import torchaudio
10
  import yaml
11
+ import argparse
12
 
13
  import os
14
  os.environ['CURL_CA_BUNDLE'] = ''
15
+ SAMPLE_RATE = 32000
16
 
17
 
18
  def denoising_transform(audio, model):
 
21
  src_path.parent.mkdir(exist_ok=True, parents=True)
22
  tgt_path.parent.mkdir(exist_ok=True, parents=True)
23
  (ffmpeg.input(audio)
24
+ .output(src_path.as_posix(), acodec='pcm_s16le', ac=1, ar=SAMPLE_RATE)
25
  .run()
26
  )
27
+ wav, rate = torchaudio.load(src_path)
28
  reduced_noise = model.predict(wav)
29
  torchaudio.save(tgt_path, reduced_noise, rate)
30
+ return src_path, tgt_path
31
 
32
 
33
+ def run_app(model_filename, config_filename, port, concurrency_count, max_size):
34
  model_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=model_filename)
35
  config_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=config_filename)
36
  with open(config_path, 'r') as f:
 
39
  checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
40
  model.load_state_dict(checkpoint['model_state_dict'])
41
 
42
+ title = "Chinese-to-English Direct Speech-to-Speech Translation (BETA)"
43
+
44
+
45
+ with gr.Blocks(title=title) as app:
46
+ with gr.Row():
47
+ with gr.Column():
48
+ gr.Markdown(
49
+ """
50
+ # Denoising
51
+ ## Instruction: \n
52
+ 1. Press "Record from microphone"
53
+ 2. Press "Stop recording"
54
+ 3. Press "Enhance" \n
55
+ - You can switch to the tab "File" to upload a prerecorded .wav audio instead of recording from microphone.
56
+ """
57
+ )
58
+ with gr.Tab("Microphone"):
59
+ microphone = gr.Audio(label="Source Audio", source="microphone", type='filepath')
60
+ with gr.Row():
61
+ microphone_button = gr.Button("Enhance", variant="primary")
62
+ with gr.Tab("File"):
63
+ upload = gr.Audio(label="Upload Audio", source="upload", type='filepath')
64
+ with gr.Row():
65
+ upload_button = gr.Button("Enhance", variant="primary")
66
+ clear_btn = gr.Button("Clear")
67
+ gr.Examples(examples=[[path] for path in Path("testing/wavs/").glob("*.wav")],
68
+ inputs=[microphone, upload])
69
+
70
+ with gr.Column():
71
+ outputs = [gr.Audio(label="Input Audio", type='filepath'),
72
+ gr.Audio(label="Demucs Enhancement", type='filepath'),
73
+ gr.Audio(label="Spectral Gating Enhancement", type='filepath')
74
+ ]
75
+
76
+ def submit(audio):
77
+ src_path, demucs_tgt_path = denoising_transform(audio, model)
78
+ _, spectral_gating_tgt_path = denoising_transform(audio, SpectralGating())
79
+ return src_path, demucs_tgt_path, spectral_gating_tgt_path, gr.update(visible=False), gr.update(visible=False)
80
+
81
+
82
+
83
+ microphone_button.click(
84
+ submit,
85
+ microphone,
86
+ outputs + [microphone, upload]
87
+ )
88
+ upload_button.click(
89
+ submit,
90
+ upload,
91
+ outputs + [microphone, upload]
92
+ )
93
+
94
+
95
+ def restart():
96
+ return microphone.update(visible=True, value=None), upload.update(visible=True, value=None), None, None, None
97
+
98
+ clear_btn.click(restart, inputs=[], outputs=[microphone, upload] + outputs)
99
+
100
+ app.queue(concurrency_count=concurrency_count, max_size=max_size)
101
+
102
+ app.launch(
103
+ ssl_verify=False,
104
+ server_name='0.0.0.0',
105
+ server_port=port,
106
+ ssl_keyfile='certificates/example.key',
107
+ ssl_certfile='certificates/example.crt',
108
  )
109
+
110
+
 
 
 
111
 
112
 
113
  if __name__ == "__main__":
114
+ parser = argparse.ArgumentParser(description='Running demo.')
115
+ parser.add_argument('--port',
116
+ type=int,
117
+ default=7860)
118
+ parser.add_argument('--model_filename',
119
+ type=str,
120
+ default="paper_replica_10_epoch/Demucs_replicate_paper_continue_epoch45.pt")
121
+ parser.add_argument('--config_filename',
122
+ type=str,
123
+ default="paper_replica_10_epoch/config.yaml")
124
+ parser.add_argument('--concurrency_count',
125
+ type=int,
126
+ default=4)
127
+ parser.add_argument('--max_size',
128
+ type=int,
129
+ default=15)
130
+
131
+ args = parser.parse_args()
132
+
133
+
134
+ run_app(args.model_filename, args.config_filename, args.port, args.concurrency_count, args.max_size)