Commit fc851dd (parent: c62c695)
Updated all for hf spaces
README.md
CHANGED
@@ -1,3 +1,9 @@
+---
+title: RemFx
+app_file: app.py
+sdk: gradio
+sdk_version: 3.41.2
+---
 <div align="center">
 
 # RemFx
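The block added between the `---` markers is YAML front matter that Hugging Face Spaces reads to configure the runtime: `title` names the Space, `app_file` points at the entry script, and `sdk`/`sdk_version` pin the Gradio build the Space runs on. For illustration only, a minimal sketch of reading such front matter yourself (assumes PyYAML is available; the Hub's actual parser is internal and may differ):

```python
import yaml

def read_front_matter(readme_path="README.md"):
    """Return the YAML dict between the two leading '---' markers, if any."""
    with open(readme_path) as f:
        text = f.read()
    if not text.startswith("---"):
        return {}
    # Split off the block between the first and second '---' delimiters.
    _, block, _ = text.split("---", 2)
    return yaml.safe_load(block)

meta = read_front_matter()
print(meta.get("app_file"), meta.get("sdk_version"))  # app.py 3.41.2
```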
app.py
CHANGED
@@ -1,9 +1,181 @@
 import gradio as gr
-
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
+import torch
+import torchaudio
+import hydra
+from hydra import compose, initialize
+import random
+from remfx import effects
+
+cfg = None
+classifier = None
+models = {}
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ALL_EFFECTS = effects.Pedalboard_Effects
+
+
+def init_hydra():
+    global cfg
+    initialize(config_path="cfg", job_name="remfx", version_base="2.0")
+    cfg = compose(config_name="config", overrides=["+exp=remfx_detect"])
+
+
+def load_models():
+    global classifier
+    print("Loading models")
+    classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
+    ckpt_path = cfg.classifier_ckpt
+    state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
+    classifier.load_state_dict(state_dict)
+    classifier.to(device)
+
+    for effect in cfg.ckpts:
+        model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
+        ckpt_path = cfg.ckpts[effect].ckpt_path
+        state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
+        model.load_state_dict(state_dict)
+        model.to(device)
+        models[effect] = model
+
+
+def audio_classification(audio_file):
+    audio, sr = torchaudio.load(audio_file)
+    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
+    # Convert to mono first (averaging the channel dim), then batch;
+    # doing it the other way round would average the batch dim instead
+    audio = audio.mean(0, keepdim=True)
+    # Add dimension for batch
+    audio = audio.unsqueeze(0)
+    audio = audio.to(device)
+
+    with torch.no_grad():
+        # Classify
+        print("Detecting effects")
+        labels = torch.tensor(classifier(audio))
+        labels_dict = {
+            ALL_EFFECTS[i].__name__.replace("RandomPedalboard", ""): labels[i].item()
+            for i in range(len(ALL_EFFECTS))
+        }
+    return labels_dict
+
+
+def audio_removal(audio_file, labels, threshold):
+    audio, sr = torchaudio.load(audio_file)
+    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
+    # Convert to mono first (averaging the channel dim), then batch
+    audio = audio.mean(0, keepdim=True)
+    # Add dimension for batch
+    audio = audio.unsqueeze(0)
+    audio = audio.to(device)
+
+    label_names = [f"RandomPedalboard{lab['label']}" for lab in labels["confidences"]]
+    logits = torch.tensor([lab["confidence"] for lab in labels["confidences"]])
+    rem_fx_labels = torch.where(logits > threshold, 1.0, 0.0)
+    effects_present = [
+        name for name, effect in zip(label_names, rem_fx_labels) if effect == 1.0
+    ]
+    print("Removing effects:", effects_present)
+    # Remove effects, shuffling the application order
+    effects_order = cfg.inference_effects_ordering
+    random.shuffle(effects_order)
+    # Keep only the detected effects, in the shuffled order
+    effects = [effect for effect in effects_order if effect in effects_present]
+    elem = audio
+    with torch.no_grad():
+        for effect in effects:
+            # Sample the removal model for this effect
+            elem = models[effect].model.sample(elem)
+    output = elem.squeeze(0)
+    # Move to CPU before converting to numpy for the waveform video
+    waveform = gr.make_waveform((cfg.sample_rate, output[0].cpu().numpy()))
+
+    return waveform
+
+
+def ui():
+    css = """
+    #classifier {
+        padding-top: 40px;
+    }
+    #classifier .output-class {
+        display: none;
+    }
+    """
+    with gr.Blocks(css=css) as interface:
+        gr.HTML(
+            """
+            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+              <div
+                style="
+                  display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
+                "
+              >
+                <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
+                  RemFX: General Purpose Audio Effect Removal
+                </h1>
+              </div> <p style="margin-bottom: 10px; font-size: 94%">
+                <a href="https://arxiv.org/abs/2301.12503">[Paper]</a> <a href="https://csteinmetz1.github.io/RemFX/">[Project
+                page]</a>
+              </p>
+            </div>
+            """
+        )
+        gr.HTML(
+            """
+            <div style="text-align: left;"> This is our demo for the paper General Purpose Audio Effect Removal. It uses the RemFX Detect system described in the paper to detect the audio effects that are present and remove them. <br>
+            To use the demo, use one of our curated examples or upload your own audio file and click submit. The system will then detect the effects present in the audio and remove them if they meet the threshold. </div>
+            """
+        )
+        with gr.Row():
+            with gr.Column():
+                effected_audio = gr.Audio(
+                    source="upload",
+                    type="filepath",
+                    label="File",
+                    interactive=True,
+                    elem_id="melody-input",
+                )
+                submit = gr.Button("Submit")
+                threshold = gr.Slider(
+                    minimum=0.0,
+                    maximum=1.0,
+                    step=0.1,
+                    value=0.5,
+                    label="Detection Threshold",
+                )
+            with gr.Column():
+                classifier = gr.Label(
+                    num_top_classes=5, label="Effects Present", elem_id="classifier"
+                )
+                audio_output = gr.Video(label="Output")
+        gr.Examples(
+            fn=audio_removal,
+            examples=[
+                ["./input_examples/guitar.wav"],
+                ["./input_examples/vocal.wav"],
+                ["./input_examples/bass.wav"],
+                ["./input_examples/drums.wav"],
+                ["./input_examples/crazy_guitar.wav"],
+            ],
+            inputs=effected_audio,
+        )
+        submit.click(
+            audio_classification,
+            inputs=[effected_audio],
+            outputs=[classifier],
+            queue=False,
+            show_progress=False,
+        ).then(
+            audio_removal,
+            inputs=[effected_audio, classifier, threshold],
+            outputs=[audio_output],
+        )
+
+    interface.queue().launch()
+
+
+if __name__ == "__main__":
+    init_hydra()
+    load_models()
+    ui()
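Worth noting how `audio_removal` consumes the classifier output: the `Label` payload arrives as a dict of per-effect confidences, which is thresholded into a binary mask and mapped back to model names. A self-contained sketch of that selection step, with made-up confidence values:

```python
import torch

# Shape of the payload a gr.Label output passes to a chained .then() callback;
# the confidence values here are invented for illustration.
labels = {
    "confidences": [
        {"label": "Distortion", "confidence": 0.91},
        {"label": "Reverb", "confidence": 0.62},
        {"label": "Chorus", "confidence": 0.18},
    ]
}
threshold = 0.5

label_names = [f"RandomPedalboard{lab['label']}" for lab in labels["confidences"]]
logits = torch.tensor([lab["confidence"] for lab in labels["confidences"]])
keep = torch.where(logits > threshold, 1.0, 0.0)
effects_present = [name for name, k in zip(label_names, keep) if k == 1.0]
print(effects_present)  # ['RandomPedalboardDistortion', 'RandomPedalboardReverb']
```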
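The `submit.click(...).then(...)` chain is what ties detection to removal: the first callback fills the `Label`, and the chained step re-reads that same `Label` together with the slider. A stripped-down sketch of the pattern without the RemFX models (all component and function names here are illustrative stand-ins, and the `Label` input payload mirrors the `{"confidences": [...]}` dict app.py consumes):

```python
import gradio as gr

def detect(_audio_path):
    # Stand-in for audio_classification: label -> confidence scores.
    return {"Reverb": 0.8, "Chorus": 0.2}

def remove(_audio_path, labels, threshold):
    # Stand-in for audio_removal: reads the Label payload produced by detect().
    kept = [l["label"] for l in labels["confidences"] if l["confidence"] > threshold]
    return f"Would remove: {kept}"

with gr.Blocks() as demo:
    audio_in = gr.Textbox(label="Audio path")
    detected = gr.Label(label="Detected effects")
    result = gr.Textbox(label="Result")
    threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Threshold")
    submit = gr.Button("Submit")
    # .click() fills the Label; .then() consumes it alongside the slider.
    submit.click(detect, inputs=[audio_in], outputs=[detected]).then(
        remove, inputs=[audio_in, detected, threshold], outputs=[result]
    )

if __name__ == "__main__":
    demo.launch()
```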
setup.py
CHANGED
@@ -53,6 +53,7 @@ setup(
         "torchmetrics>=1.0",
         "wav2clip_hear @ git+https://github.com/hohsiangwu/wav2clip-hear.git",
         "panns_hear @ git+https://github.com/qiuqiangkong/HEAR2021_Challenge_PANNs",
+        "gradio",
     ],
     include_package_data=True,
     license="Apache License 2.0",
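With `gradio` added to `install_requires`, a local editable install (`pip install -e .`) now pulls in the UI dependency alongside the model stack. A quick, hypothetical sanity check that the runtime dependencies used by app.py resolve before the Space boots:

```python
import importlib.util

# Import names for the runtime dependencies used by app.py
# ("hydra" is provided by the hydra-core distribution).
for dep in ("gradio", "torch", "torchaudio", "hydra"):
    found = importlib.util.find_spec(dep) is not None
    print(f"{dep}: {'ok' if found else 'MISSING'}")
```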