import gradio as gr
import torch
import torchaudio
import hydra
from hydra import compose, initialize
import random
import os
from remfx import effects

cfg = None
classifier = None
models = {}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ALL_EFFECTS = effects.Pedalboard_Effects
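

# Compose the Hydra config for the "remfx_detect" experiment; called once at startup.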
def init_hydra():
    global cfg
    initialize(config_path="cfg", job_name="remfx", version_base="2.0")
    cfg = compose(config_name="config", overrides=["+exp=remfx_detect"])
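

# Instantiate the effect classifier and one removal model per effect from the config,
# then load their checkpoints onto the selected device.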
def load_models():
    global classifier
    print("Loading models")
    classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
    ckpt_path = cfg.classifier_ckpt
    state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
    classifier.load_state_dict(state_dict)
    classifier.to(device)
    for effect in cfg.ckpts:
        model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
        ckpt_path = cfg.ckpts[effect].ckpt_path
        state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
        model.load_state_dict(state_dict)
        model.to(device)
        models[effect] = model
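

# Gradio callback: resample the uploaded file, run the classifier, and return a
# {effect name: confidence} dict for the Label component.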
def audio_classification(audio_file):
    audio, sr = torchaudio.load(audio_file)
    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
    # Convert to mono
    audio = audio.mean(0, keepdim=True)
    # Add dimension for batch
    audio = audio.unsqueeze(0)
    audio = audio.to(device)
    with torch.no_grad():
        # Classify
        print("Detecting effects")
        labels = torch.tensor(classifier(audio))
    labels_dict = {
        ALL_EFFECTS[i].__name__.replace("RandomPedalboard", ""): labels[i].item()
        for i in range(len(ALL_EFFECTS))
    }
    return labels_dict
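

# Gradio callback: threshold the classifier confidences, then run the matching
# removal models back to back in a shuffled inference order.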
def audio_removal(audio_file, labels, threshold):
    audio, sr = torchaudio.load(audio_file)
    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
    # Convert to mono
    audio = audio.mean(0, keepdim=True)
    # Add dimension for batch
    audio = audio.unsqueeze(0)
    audio = audio.to(device)
    label_names = [f"RandomPedalboard{lab['label']}" for lab in labels["confidences"]]
    logits = torch.tensor([lab["confidence"] for lab in labels["confidences"]])
    rem_fx_labels = torch.where(logits > threshold, 1.0, 0.0)
    effects_present = [
        name for name, effect in zip(label_names, rem_fx_labels) if effect == 1.0
    ]
    print("Removing effects:", effects_present)
    # Remove effects
    # Shuffle effects order
    effects_order = cfg.inference_effects_ordering
    random.shuffle(effects_order)
    # Keep only the detected effects, in the shuffled ordering
    effects = [effect for effect in effects_order if effect in effects_present]
    elem = audio
    with torch.no_grad():
        for effect in effects:
            # Sample the model
            elem = models[effect].model.sample(elem)
    # Move back to CPU before converting to numpy for the waveform renderer
    output = elem.squeeze(0).cpu()
    waveform = gr.make_waveform((cfg.sample_rate, output[0].numpy()))
    return waveform
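

# Build the Gradio Blocks interface and wire the two callbacks together.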
def ui():
    css = """
    #classifier {
        padding-top: 40px;
    }
    #classifier .output-class {
        display: none;
    }
    """
    with gr.Blocks(css=css) as interface:
        gr.HTML(
            """
            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
              <div
                style="
                  display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
                "
              >
                <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
                  RemFX: General Purpose Audio Effect Removal
                </h1>
              </div>
              <p style="margin-bottom: 10px; font-size: 94%">
                <a href="https://arxiv.org/abs/2308.16177">[Paper]</a>
                <a href="https://csteinmetz1.github.io/RemFX/">[Project page]</a>
              </p>
            </div>
            """
        )
        gr.HTML(
            """
            <div style="text-align: left;">
              This is our demo for the paper General Purpose Audio Effect Removal. It uses the
              RemFX Detect system described in the paper to detect the audio effects that are
              present and remove them. <br>
              To use the demo, pick one of our curated examples or upload your own audio file
              and click Submit. The system will then detect the effects present in the audio
              and remove them if they meet the threshold.
            </div>
            """
        )
        with gr.Row():
            with gr.Column():
                effected_audio = gr.Audio(
                    source="upload",
                    type="filepath",
                    label="File",
                    interactive=True,
                    elem_id="melody-input",
                )
                submit = gr.Button("Submit")
                threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.5,
                    label="Detection Threshold",
                )
            with gr.Column():
                classifier = gr.Label(
                    num_top_classes=5, label="Effects Present", elem_id="classifier"
                )
                audio_output = gr.Video(label="Output")
        gr.Examples(
            fn=audio_removal,
            examples=[
                ["./input_examples/guitar.wav"],
                ["./input_examples/vocal.wav"],
                ["./input_examples/bass.wav"],
                ["./input_examples/drums.wav"],
                ["./input_examples/crazy_guitar.wav"],
            ],
            inputs=effected_audio,
        )
        submit.click(
            audio_classification,
            inputs=[effected_audio],
            outputs=[classifier],
            queue=False,
            show_progress=False,
        ).then(
            audio_removal,
            inputs=[effected_audio, classifier, threshold],
            outputs=[audio_output],
        )
    interface.queue().launch()
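

# Entry point: compose the config, load the checkpoints, then launch the demo.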
if __name__ == "__main__":
    init_hydra()
    load_models()
    ui()