File size: 4,878 Bytes
091b1e0
 
 
f6561a5
091b1e0
dc3ceb7
 
 
b80b88c
 
a6a74d4
091b1e0
33534ec
 
a6a74d4
33534ec
0837029
c45e107
 
 
 
 
 
a6a74d4
c45e107
 
a6a74d4
c45e107
 
a6a74d4
c45e107
dc3ceb7
a6a74d4
b80b88c
 
 
 
 
 
dc3ceb7
 
98d175b
a6a74d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c45e107
a6a74d4
 
091b1e0
dc3ceb7
 
a6a74d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import uuid
import ffmpeg
import gradio as gr
from pathlib import Path
from denoisers.SpectralGating import SpectralGating
from huggingface_hub import hf_hub_download
from denoisers.demucs import Demucs
import torch
import torchaudio
import yaml
import argparse

import os
# NOTE(review): an empty CURL_CA_BUNDLE presumably turns off TLS certificate
# verification for HTTP requests made by this process — confirm that this is
# intentional and acceptable for the deployment.
os.environ['CURL_CA_BUNDLE'] = ''
# Sample rate handed to ffmpeg (``ar=``) when input audio is converted to WAV.
SAMPLE_RATE = 32000


def denoising_transform(audio, model):
    """Convert *audio* to WAV, run *model* on it and store both versions.

    The input is re-encoded with ffmpeg to single-channel ``pcm_s16le`` at
    ``SAMPLE_RATE`` and written under ``cache_wav/original``; the output of
    ``model.predict`` is written under ``cache_wav/denoised``. Each file gets
    its own freshly generated UUID as name.

    Args:
        audio: Input passed straight to ``ffmpeg.input``.
        model: Object exposing ``predict(wav)`` whose result can be saved
            with ``torchaudio.save``.

    Returns:
        Tuple ``(src_path, tgt_path)`` with the converted source file and the
        denoised file.
    """
    cache_root = Path("cache_wav")
    src_path = cache_root / "original" / f"{uuid.uuid4()}.wav"
    tgt_path = cache_root / "denoised" / f"{uuid.uuid4()}.wav"
    for wav_path in (src_path, tgt_path):
        wav_path.parent.mkdir(parents=True, exist_ok=True)

    # Normalise whatever the caller supplied into a WAV file on disk.
    stream = ffmpeg.input(audio)
    stream = stream.output(src_path.as_posix(), acodec='pcm_s16le', ac=1, ar=SAMPLE_RATE)
    stream.run()

    waveform, sample_rate = torchaudio.load(src_path)
    torchaudio.save(tgt_path, model.predict(waveform), sample_rate)
    return src_path, tgt_path


def run_app(model_filename, config_filename, port, concurrency_count, max_size):
    """Download the Demucs checkpoint, build the Gradio demo and serve it.

    The UI offers a microphone tab and a file-upload tab. Pressing "Enhance"
    shows the converted input next to the Demucs and spectral-gating results
    and hides both input widgets; "Clear" resets everything.

    Args:
        model_filename: Checkpoint file inside the "BorisovMaksim/demucs"
            Hugging Face Hub repository.
        config_filename: YAML config file inside the same repository; its
            ``demucs`` section is passed to the ``Demucs`` constructor.
        port: Port the Gradio server listens on.
        concurrency_count: Forwarded to ``app.queue(concurrency_count=...)``.
        max_size: Forwarded to ``app.queue(max_size=...)``.
    """
    model_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=model_filename)
    config_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=config_filename)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    model = Demucs(config['demucs'])
    # NOTE(security): torch.load unpickles the downloaded file, so only point
    # this at a checkpoint repository you trust.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    # The model is only used for inference here; make sure training-only
    # layer behaviour is switched off.
    model.eval()

    title = "Denoising"

    with gr.Blocks(title=title) as app:
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                        """
                    # Denoising
                    ## Instruction: \n
                    1. Press "Record from microphone"
                    2. Press "Stop recording"
                    3. Press "Enhance" \n
                    - You can switch to the tab "File" to upload a prerecorded .wav audio  instead of recording from microphone.
                    """
                    )
                with gr.Tab("Microphone"):
                    microphone = gr.Audio(label="Source Audio", source="microphone", type='filepath')
                    with gr.Row():
                        microphone_button = gr.Button("Enhance", variant="primary")
                with gr.Tab("File"):
                    upload = gr.Audio(label="Upload Audio", source="upload", type='filepath')
                    with gr.Row():
                        upload_button = gr.Button("Enhance", variant="primary")
                clear_btn = gr.Button("Clear")
                gr.Examples(examples=[[path] for path in Path("testing/wavs/").glob("*.wav")],
                            inputs=[microphone, upload])

            with gr.Column():
                outputs = [gr.Audio(label="Input Audio", type='filepath'),
                           gr.Audio(label="Demucs Enhancement", type='filepath'),
                           gr.Audio(label="Spectral Gating Enhancement", type='filepath')
                           ]

        def submit(audio):
            """Denoise *audio* with both models and hide the two input widgets."""
            # "Enhance" can be pressed before anything was recorded/uploaded;
            # report that clearly instead of failing inside ffmpeg.
            if audio is None:
                raise gr.Error("Please record or upload an audio clip before pressing Enhance.")
            src_path, demucs_tgt_path = denoising_transform(audio, model)
            _, spectral_gating_tgt_path = denoising_transform(audio, SpectralGating())
            return (src_path, demucs_tgt_path, spectral_gating_tgt_path,
                    gr.update(visible=False), gr.update(visible=False))

        # Both tabs share the same handler; the last two outputs are the input
        # widgets themselves so that `submit` can hide them.
        microphone_button.click(
            submit,
            microphone,
            outputs + [microphone, upload]
        )
        upload_button.click(
            submit,
            upload,
            outputs + [microphone, upload]
        )

        def restart():
            """Show and empty both input widgets and clear the three outputs."""
            return (gr.update(visible=True, value=None),
                    gr.update(visible=True, value=None),
                    None, None, None)

        clear_btn.click(restart, inputs=[], outputs=[microphone, upload] + outputs)

    app.queue(concurrency_count=concurrency_count, max_size=max_size)

    app.launch(
        server_name='0.0.0.0',
        server_port=port,
    )




if __name__ == "__main__":
    # Command-line entry point: every option has a default, so the demo can be
    # started without arguments.
    cli = argparse.ArgumentParser(description='Running demo.')
    cli.add_argument('--port', type=int, default=7860)
    cli.add_argument(
        '--model_filename',
        type=str,
        default="paper_replica_10_epoch/Demucs_replicate_paper_continue_epoch45.pt",
    )
    cli.add_argument(
        '--config_filename',
        type=str,
        default="paper_replica_10_epoch/config.yaml",
    )
    cli.add_argument('--concurrency_count', type=int, default=4)
    cli.add_argument('--max_size', type=int, default=15)

    options = cli.parse_args()

    run_app(
        model_filename=options.model_filename,
        config_filename=options.config_filename,
        port=options.port,
        concurrency_count=options.concurrency_count,
        max_size=options.max_size,
    )