import gradio as gr
from model.CLAPSep import CLAPSep
import torch
import librosa
import numpy as np


# Hyper-parameters of the CLAPSep separation network (passed to the CLAPSep constructor below)
model_config = {
    "lan_embed_dim": 1024,
    "depths": [1, 1, 1, 1],
    "embed_dim": 128,
    "encoder_embed_dim": 128,
    "phase": False,
    "spec_factor": 8,
    "d_attn": 640,
    "n_masker_layer": 3,
    "conv": False,
}
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pretrained CLAP checkpoint (LAION-CLAP) used to embed the text/audio queries
CLAP_path = "model/music_audioset_epoch_15_esc_90.14.pt"


# Build the separation model and load its trained weights
# (strict=False tolerates missing/unexpected keys in the checkpoint)
model = CLAPSep(model_config, CLAP_path).to(DEVICE)
ckpt = torch.load('model/best_model.ckpt', map_location=DEVICE)
model.load_state_dict(ckpt, strict=False)
model.eval()



def inference(audio_file_path: str, text_p: str, audio_file_path_p: str, text_n: str, audio_file_path_n: str):
    """Separate the source described by the positive query (text and/or audio) from the mixture,
    while suppressing whatever the negative query describes."""
    # Encode the positive/negative queries with the CLAP encoders; an empty text query and a
    # missing audio query each contribute a zero embedding.
    with torch.no_grad():
        embed_pos, embed_neg = torch.chunk(model.clap_model.get_text_embedding([text_p, text_n],
                                                                               use_tensor=True), dim=0, chunks=2)
        embed_pos = torch.zeros_like(embed_pos) if text_p == '' else embed_pos
        embed_neg = torch.zeros_like(embed_neg) if text_n == '' else embed_neg
        embed_pos += (model.clap_model.get_audio_embedding_from_filelist(
            [audio_file_path_p]) if audio_file_path_p is not None else torch.zeros_like(embed_pos))
        embed_neg += (model.clap_model.get_audio_embedding_from_filelist(
            [audio_file_path_n]) if audio_file_path_n is not None else torch.zeros_like(embed_neg))

    print(f"Separate audio from [{audio_file_path}] with textual query p: [{text_p}] and n: [{text_n}]")

    mixture, _ = librosa.load(audio_file_path, sr=32000)

    # Zero-pad the mixture so its length is a multiple of 320,000 samples (10 s at 32 kHz),
    # e.g. a 25 s clip (800,000 samples) is padded to 960,000 samples = 3 chunks.
    pad = (320000 - (len(mixture) % 320000)) if len(mixture) % 320000 != 0 else 0
    mixture = torch.tensor(np.pad(mixture, (0, pad)))

    # Peak-normalise if the waveform exceeds full scale, to avoid clipping
    max_value = torch.max(torch.abs(mixture))
    if max_value > 1:
        mixture *= 0.9 / max_value

    # Split into fixed-length 10-second chunks
    mixture_chunks = torch.chunk(mixture, dim=0, chunks=len(mixture) // 320000)
    sep_segments = []
    for chunk in mixture_chunks:
        with torch.no_grad():
            # Separate each chunk, conditioned on the positive/negative query embeddings
            sep_segments.append(model.inference_from_data(chunk.unsqueeze(0), embed_pos, embed_neg))

    # Stitch the separated chunks back together; gr.Audio expects (sample_rate, waveform)
    sep_segment = torch.concat(sep_segments, dim=1)

    return 32000, sep_segment.squeeze().cpu().numpy()


# Gradio UI: mixture and positive/negative queries on the left, separation result on the right
with gr.Blocks(title="CLAPSep") as demo:
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(label="Mixture", type="filepath")
            text_p = gr.Textbox(label="Positive Query Text")
            text_n = gr.Textbox(label="Negative Query Text")
            query_audio_p = gr.Audio(label="Positive Query Audio (optional)", type="filepath")
            query_audio_n = gr.Audio(label="Negative Query Audio (optional)", type="filepath")
        with gr.Column():
            output_audio = gr.Audio(label="Separation Result", scale=10)
            button = gr.Button(
                "Separate",
                variant="primary",
                scale=2,
                size="lg",
                interactive=True,
            )
            button.click(
                fn=inference,
                inputs=[input_audio, text_p, query_audio_p, text_n, query_audio_n],
                outputs=[output_audio],
            )


demo.queue().launch(share=True)
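
# A minimal sketch of calling the separation function programmatically instead of through the UI
# (hypothetical file names; "mixture.wav"/"separated.wav" are placeholders and the soundfile
# package is an assumption, not a dependency of this demo):
#
#     sr, separated = inference("mixture.wav", text_p="violin", audio_file_path_p=None,
#                               text_n="", audio_file_path_n=None)
#     import soundfile as sf
#     sf.write("separated.wav", separated, sr)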