# denoising / app.py
# Last commit by BorisovMaksim: "Update app.py" (98d175b)
import uuid
import ffmpeg
import gradio as gr
from pathlib import Path
from denoisers.SpectralGating import SpectralGating
from huggingface_hub import hf_hub_download
from denoisers.demucs import Demucs
import torch
import torchaudio
import yaml
import argparse
import os
# NOTE(review): an empty CURL_CA_BUNDLE is presumably the well-known workaround
# that makes some `requests` versions skip TLS certificate verification (e.g.
# for the Hugging Face Hub downloads below). It affects the whole process -
# confirm it is still needed.
os.environ.update(CURL_CA_BUNDLE="")

# Sample rate (Hz) that ffmpeg resamples every input recording to.
SAMPLE_RATE: int = 32_000
def denoising_transform(audio, model):
    """Convert ``audio`` to a mono 16-bit WAV, denoise it with ``model`` and save both.

    Args:
        audio: Path of the source recording (anything ``ffmpeg.input`` accepts),
            as handed over by a ``gr.Audio(type='filepath')`` component.
        model: Object exposing ``predict(wav)``. It receives the tensor returned by
            ``torchaudio.load`` and must return something ``torchaudio.save`` can write.

    Returns:
        ``(src_path, tgt_path)``: the converted input WAV under
        ``cache_wav/original/`` and the denoised WAV under ``cache_wav/denoised/``.

    Raises:
        ValueError: if ``audio`` is empty/``None`` (e.g. "Enhance" was pressed
            before anything was recorded or uploaded).
    """
    # Gradio passes None for an empty component; fail with a clear message
    # (and before touching the filesystem) instead of an opaque ffmpeg/subprocess error.
    if not audio:
        raise ValueError("No audio provided: record or upload a file first.")

    # One id per call so an original and its denoised version share a file name.
    file_id = uuid.uuid4()
    src_path = Path("cache_wav/original") / f"{file_id}.wav"
    tgt_path = Path("cache_wav/denoised") / f"{file_id}.wav"
    src_path.parent.mkdir(parents=True, exist_ok=True)
    tgt_path.parent.mkdir(parents=True, exist_ok=True)

    # Normalise the input to mono 16-bit PCM at SAMPLE_RATE before it is fed to the model.
    (
        ffmpeg.input(audio)
        .output(src_path.as_posix(), acodec="pcm_s16le", ac=1, ar=SAMPLE_RATE)
        .run()
    )

    wav, rate = torchaudio.load(src_path)
    reduced_noise = model.predict(wav)
    torchaudio.save(tgt_path, reduced_noise, rate)
    # NOTE(review): nothing in this file removes files from cache_wav/; they accumulate.
    return src_path, tgt_path
# Markdown shown above the input widgets.
_INSTRUCTIONS = """
# Denoising
## Instruction: \n
1. Press "Record from microphone"
2. Press "Stop recording"
3. Press "Enhance" \n
- You can switch to the tab "File" to upload a prerecorded .wav audio instead of recording from microphone.
"""


def _load_demucs(model_filename, config_filename, repo_id="BorisovMaksim/demucs"):
    """Download a Demucs checkpoint and its YAML config from the Hub and build the model.

    The ``demucs`` section of the config is passed to ``Demucs``, and the
    checkpoint's ``model_state_dict`` entry is loaded into it on CPU.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    model = Demucs(config['demucs'])
    # NOTE(review): torch.load unpickles the checkpoint, so only load files from a
    # repository you trust.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    return model


def _build_app(model):
    """Build the Blocks UI comparing Demucs with spectral-gating denoising.

    Args:
        model: Loaded Demucs model used for the "Demucs Enhancement" output.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    title = "Denoising"
    with gr.Blocks(title=title) as app:
        with gr.Row():
            with gr.Column():
                gr.Markdown(_INSTRUCTIONS)
                with gr.Tab("Microphone"):
                    microphone = gr.Audio(label="Source Audio", source="microphone", type='filepath')
                    with gr.Row():
                        microphone_button = gr.Button("Enhance", variant="primary")
                with gr.Tab("File"):
                    upload = gr.Audio(label="Upload Audio", source="upload", type='filepath')
                    with gr.Row():
                        upload_button = gr.Button("Enhance", variant="primary")
                clear_btn = gr.Button("Clear")
                # Sorted so the examples appear in a stable order (glob order is
                # filesystem-dependent).
                # NOTE(review): each example row holds one file but two input components
                # are listed; Gradio presumably binds it to the first one (microphone)
                # only - confirm that is intended.
                gr.Examples(
                    examples=[[str(path)] for path in sorted(Path("testing/wavs/").glob("*.wav"))],
                    inputs=[microphone, upload],
                )
            with gr.Column():
                outputs = [
                    gr.Audio(label="Input Audio", type='filepath'),
                    gr.Audio(label="Demucs Enhancement", type='filepath'),
                    gr.Audio(label="Spectral Gating Enhancement", type='filepath'),
                ]

        def submit(audio):
            """Denoise ``audio`` with both methods, then hide both input widgets."""
            src_path, demucs_tgt_path = denoising_transform(audio, model)
            # NOTE(review): this converts the same input with ffmpeg a second time and
            # discards that second converted copy.
            _, spectral_gating_tgt_path = denoising_transform(audio, SpectralGating())
            return (
                src_path,
                demucs_tgt_path,
                spectral_gating_tgt_path,
                gr.update(visible=False),
                gr.update(visible=False),
            )

        # Both "Enhance" buttons share one handler; the two trailing outputs are the
        # input widgets, which submit() hides.
        microphone_button.click(submit, microphone, outputs + [microphone, upload])
        upload_button.click(submit, upload, outputs + [microphone, upload])

        def restart():
            """Show and empty both input widgets and clear the three result players."""
            return (
                gr.update(visible=True, value=None),
                gr.update(visible=True, value=None),
                None,
                None,
                None,
            )

        clear_btn.click(restart, inputs=[], outputs=[microphone, upload] + outputs)
    return app


def run_app(model_filename, config_filename, port, concurrency_count, max_size):
    """Load the Demucs model, build the demo UI and serve it.

    Args:
        model_filename: Checkpoint path inside the ``BorisovMaksim/demucs`` Hub repo.
        config_filename: YAML config path inside the same repo.
        port: Port passed to ``app.launch``; the server binds ``0.0.0.0``.
        concurrency_count: Forwarded to ``app.queue(concurrency_count=...)``.
        max_size: Forwarded to ``app.queue(max_size=...)``.
    """
    model = _load_demucs(model_filename, config_filename)
    app = _build_app(model)
    app.queue(concurrency_count=concurrency_count, max_size=max_size)
    app.launch(
        server_name='0.0.0.0',
        server_port=port,
    )
if __name__ == "__main__":
    # Command-line entry point: parse options and start the demo server.
    parser = argparse.ArgumentParser(description='Running demo.')
    parser.add_argument('--port',
                        type=int,
                        default=7860,
                        help='Port for the Gradio server (default: %(default)s).')
    parser.add_argument('--model_filename',
                        type=str,
                        default="paper_replica_10_epoch/Demucs_replicate_paper_continue_epoch45.pt",
                        help='Demucs checkpoint path inside the Hugging Face Hub repo '
                             '(default: %(default)s).')
    parser.add_argument('--config_filename',
                        type=str,
                        default="paper_replica_10_epoch/config.yaml",
                        help='YAML config path inside the Hugging Face Hub repo '
                             '(default: %(default)s).')
    parser.add_argument('--concurrency_count',
                        type=int,
                        default=4,
                        help='Value forwarded to the Gradio queue\'s concurrency_count '
                             '(default: %(default)s).')
    parser.add_argument('--max_size',
                        type=int,
                        default=15,
                        help='Maximum Gradio queue size (default: %(default)s).')
    args = parser.parse_args()
    run_app(
        model_filename=args.model_filename,
        config_filename=args.config_filename,
        port=args.port,
        concurrency_count=args.concurrency_count,
        max_size=args.max_size,
    )