kevinwang676 committed on
Commit
fac919d
·
1 Parent(s): 21110de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -0
app.py CHANGED
@@ -68,6 +68,237 @@ device = torch.device("cpu")
68
  if torch.cuda.is_available():
69
  device = torch.device("cuda", 0)
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  # VALL-E-X model
72
  model = VALLE(
73
  N_DIM,
 
68
  if torch.cuda.is_available():
69
  device = torch.device("cuda", 0)
70
 
71
+ # Denoise
72
+
73
+ model1, df, _ = init_df("./DeepFilterNet2", config_allow_defaults=True)
74
+ model1 = model1.to(device=device).eval()
75
+
76
+ fig_noisy: plt.Figure
77
+ fig_enh: plt.Figure
78
+ ax_noisy: plt.Axes
79
+ ax_enh: plt.Axes
80
+ fig_noisy, ax_noisy = plt.subplots(figsize=(15.2, 4))
81
+ fig_noisy.set_tight_layout(True)
82
+ fig_enh, ax_enh = plt.subplots(figsize=(15.2, 4))
83
+ fig_enh.set_tight_layout(True)
84
+
85
+ NOISES = {
86
+ "None": None,
87
+ }
88
+
89
def mix_at_snr(clean, noise, snr, eps=1e-10):
    """Mix clean and noise signal at a given SNR.

    Args:
        clean: Tensor [C, T] with the clean signal to mix; averaged to mono.
        noise: Tensor [C, T'] with the noise; tiled/cropped to match `clean`.
        snr: Target signal-to-noise ratio in dB.
        eps: Small constant guarding against division by zero / sqrt of 0.

    Returns:
        clean: 1D Tensor with gain changed according to the snr.
        noise: 1D Tensor with the combined noise channels.
        mix: 1D Tensor with added clean and noise signals.
    """
    clean = torch.as_tensor(clean).mean(0, keepdim=True)
    noise = torch.as_tensor(noise).mean(0, keepdim=True)
    # Tile the noise so it is at least as long as the clean signal, then take
    # a random excerpt of matching length.
    if noise.shape[1] < clean.shape[1]:
        noise = noise.repeat((1, int(math.ceil(clean.shape[1] / noise.shape[1]))))
    max_start = int(noise.shape[1] - clean.shape[1])
    start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0
    logger.debug(f"start: {start}, {clean.shape}")
    noise = noise[:, start : start + clean.shape[1]]
    E_speech = torch.mean(clean.pow(2)) + eps
    E_noise = torch.mean(noise.pow(2))
    # Scale noise by 1/K so that E_speech / E_noise' == 10**(snr/10).
    K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
    noise = noise / K
    mixture = clean + noise
    # Fixed: the original lacked the f-prefix, logging the literal
    # "{mixture.shape}" instead of the actual shape.
    logger.debug(f"mixture: {mixture.shape}")
    assert torch.isfinite(mixture).all()
    max_m = mixture.abs().max()
    if max_m > 1:
        # Normalize all three signals together so relative levels (the SNR)
        # are preserved while avoiding clipping.
        logger.warning(f"Clipping detected during mixing. Reducing gain by {1/max_m}")
        clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m
    return clean, noise, mixture
120
+
121
+
122
def load_audio_gradio(
    audio_or_file: Union[None, str, Tuple[int, np.ndarray]], sr: int
) -> Optional[Tuple[Tensor, AudioMetaData]]:
    """Normalize a gradio audio input to a [channels, samples] tensor at `sr`.

    Accepts either a file path or a gradio ``(sample_rate, samples)`` tuple;
    returns ``None`` for empty input ("None"/None selections).
    """
    if audio_or_file is None:
        return None
    if isinstance(audio_or_file, str):
        if audio_or_file.lower() == "none":
            return None
        # File path: let load_audio handle decoding and resampling.
        audio, meta = load_audio(audio_or_file, sr)
        return audio, meta
    assert isinstance(audio_or_file, (tuple, list))
    meta = AudioMetaData(-1, -1, -1, -1, "")
    meta.sample_rate, raw = audio_or_file
    # Gradio documentation says the shape is [samples, 2], but apparently
    # sometimes it is not — flatten defensively, then transpose to [C, T].
    raw = raw.reshape(raw.shape[0], -1).T
    # Integer PCM -> float32 in [-1, 1).
    if raw.dtype == np.int16:
        raw = (raw / (1 << 15)).astype(np.float32)
    elif raw.dtype == np.int32:
        raw = (raw / (1 << 31)).astype(np.float32)
    audio = resample(torch.from_numpy(raw), meta.sample_rate, sr)
    return audio, meta
144
+
145
+
146
def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: str):
    """Denoise an uploaded or recorded clip, optionally mixing in noise first.

    Args:
        speech_upl: Path to an uploaded speech file (may be None).
        noise_type: Key into the module-level NOISES mapping.
        snr: Target SNR in dB used when a noise file is mixed in.
        mic_input: Path to a microphone recording; takes precedence if set.

    Returns:
        (noisy_wav_path, noisy_spec_image, enhanced_wav_path, enhanced_spec_image)
        matching the gradio output components.
    """
    if mic_input:
        speech_upl = mic_input
    sr = config("sr", 48000, int, section="df")
    logger.info(f"Got parameters speech_upl: {speech_upl}, noise: {noise_type}, snr: {snr}")
    snr = int(snr)
    noise_fn = NOISES[noise_type]
    meta = AudioMetaData(-1, -1, -1, -1, "")
    max_s = 10  # limit to 10 seconds
    if speech_upl is not None:
        sample, meta = load_audio(speech_upl, sr)
        max_len = max_s * sr
        if sample.shape[-1] > max_len:
            # Take a random 10 s excerpt of longer uploads.
            start = torch.randint(0, sample.shape[-1] - max_len, ()).item()
            sample = sample[..., start : start + max_len]
    else:
        # Fall back to a bundled clean sample.
        sample, meta = load_audio("samples/p232_013_clean.wav", sr)
        sample = sample[..., : max_s * sr]
    if sample.dim() > 1 and sample.shape[0] > 1:
        assert (
            sample.shape[1] > sample.shape[0]
        ), f"Expecting channels first, but got {sample.shape}"
        sample = sample.mean(dim=0, keepdim=True)
    logger.info(f"Loaded sample with shape {sample.shape}")
    if noise_fn is not None:
        noise, _ = load_audio(noise_fn, sr)  # type: ignore
        logger.info(f"Loaded noise with shape {noise.shape}")
        _, _, sample = mix_at_snr(sample, noise, snr)
    logger.info("Start denoising audio")
    enhanced = enhance(model1, df, sample)
    logger.info("Denoising finished")
    # 150 ms linear fade-in. Clamp the ramp to the clip length: the original
    # torch.ones(1, enhanced.shape[1] - lim.shape[1]) raised for clips
    # shorter than 0.15 s because the second dimension went negative.
    fade_len = min(int(sr * 0.15), enhanced.shape[1])
    lim = torch.linspace(0.0, 1.0, fade_len).unsqueeze(0)
    lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
    enhanced = enhanced * lim
    if meta.sample_rate != sr:
        # Hand audio back at the caller's original sample rate.
        enhanced = resample(enhanced, sr, meta.sample_rate)
        sample = resample(sample, sr, meta.sample_rate)
        sr = meta.sample_rate
    noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
    save_audio(noisy_wav, sample, sr)
    enhanced_wav = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
    save_audio(enhanced_wav, enhanced, sr)
    logger.info(f"saved audios: {noisy_wav}, {enhanced_wav}")
    # Reuse the module-level figures; clear the axes between requests.
    ax_noisy.clear()
    ax_enh.clear()
    noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy)
    enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh)
    return noisy_wav, noisy_im, enhanced_wav, enh_im
196
+
197
+
198
def specshow(
    spec,
    ax=None,
    title=None,
    xlabel=None,
    ylabel=None,
    sr=48000,
    n_fft=None,
    hop=None,
    t=None,
    f=None,
    vmin=-100,
    vmax=0,
    xlim=None,
    ylim=None,
    cmap="inferno",
):
    """Plots a spectrogram of shape [F, T]"""
    spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
    # Pick either the axes-level or the pyplot-level API up front so the
    # plotting code below is branch-free.
    if ax is not None:
        set_title, set_xlabel, set_ylabel = ax.set_title, ax.set_xlabel, ax.set_ylabel
        set_xlim, set_ylim = ax.set_xlim, ax.set_ylim
    else:
        ax = plt
        set_title, set_xlabel, set_ylabel = plt.title, plt.xlabel, plt.ylabel
        set_xlim, set_ylim = plt.xlim, plt.ylim
    if n_fft is None:
        # An odd bin count F corresponds to a one-sided spectrum of an
        # (F - 1) * 2 point FFT; an even count to F * 2.
        n_fft = spec.shape[0] * 2 if spec.shape[0] % 2 == 0 else (spec.shape[0] - 1) * 2
    hop = hop or n_fft // 4
    if t is None:
        t = np.arange(0, spec_np.shape[-1]) * hop / sr
    if f is None:
        f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
    im = ax.pcolormesh(
        t, f, spec_np, rasterized=True, shading="auto", vmin=vmin, vmax=vmax, cmap=cmap
    )
    # Apply only the decorations the caller asked for.
    for value, setter in (
        (title, set_title),
        (xlabel, set_xlabel),
        (ylabel, set_ylabel),
        (xlim, set_xlim),
        (ylim, set_ylim),
    ):
        if value is not None:
            setter(value)
    return im
254
+
255
+
256
def spec_im(
    audio: torch.Tensor,
    figsize=(15, 5),
    colorbar=False,
    colorbar_format=None,
    figure=None,
    labels=True,
    **kwargs,
) -> Image:
    """Render a dB-scaled magnitude spectrogram of `audio` as a PIL image.

    Args:
        audio: Waveform tensor; leading singleton channel dim is squeezed.
        figsize: Figure size used only when `figure` is not supplied.
        colorbar: Whether to attach a colorbar to the plot.
        colorbar_format: Tick label format; defaults to "%+2.0f dB" when
            vmin/vmax are set and an axes is supplied.
        figure: Existing figure to draw into (reused by demo_fn).
        labels: Whether to set default axis labels.
        **kwargs: Forwarded to specshow() (sr=, ax=, vmin=, ...).
    """
    audio = torch.as_tensor(audio)
    if labels:
        kwargs.setdefault("xlabel", "Time [s]")
        kwargs.setdefault("ylabel", "Frequency [Hz]")
    n_fft = kwargs.setdefault("n_fft", 1024)
    hop = kwargs.setdefault("hop", 512)
    w = torch.hann_window(n_fft, device=audio.device)
    # Fixed: return_complex=False is deprecated in torch.stft and slated for
    # removal; take abs() of the complex output instead (same magnitudes).
    spec = torch.stft(audio, n_fft, hop, window=w, return_complex=True)
    spec = spec / w.pow(2).sum()
    spec = spec.abs().clamp_min(1e-12).log10().mul(10)
    kwargs.setdefault("vmax", max(0.0, spec.max().item()))

    if figure is None:
        figure = plt.figure(figsize=figsize)
        figure.set_tight_layout(True)
    if spec.dim() > 2:
        spec = spec.squeeze(0)
    im = specshow(spec, **kwargs)
    if colorbar:
        ckwargs = {}
        if "ax" in kwargs:
            if colorbar_format is None:
                if kwargs.get("vmin", None) is not None or kwargs.get("vmax", None) is not None:
                    colorbar_format = "%+2.0f dB"
            ckwargs = {"ax": kwargs["ax"]}
        plt.colorbar(im, format=colorbar_format, **ckwargs)
    figure.canvas.draw()
    # Fixed: canvas.tostring_rgb() was removed in Matplotlib 3.10; buffer_rgba()
    # (available since 3.1) plus an RGB convert yields the same image.
    return Image.frombytes(
        "RGBA", figure.canvas.get_width_height(), bytes(figure.canvas.buffer_rgba())
    ).convert("RGB")
293
+
294
+
295
def toggle(choice):
    """Swap visibility of the mic and file inputs based on the radio `choice`.

    Returns gradio updates for (mic_component, file_component); both values
    are reset to None on every toggle.
    """
    use_mic = choice == "mic"
    return (
        gr.update(visible=use_mic, value=None),
        gr.update(visible=not use_mic, value=None),
    )
300
+
301
+
302
  # VALL-E-X model
303
  model = VALLE(
304
  N_DIM,