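"""Gradio demo for "Music De-limiter via Sample-wise Gain Inversion" (WASPAA 2023).

Loads a pre-trained de-limiter model, runs it on an uploaded 44.1 kHz track, and
lets the user blend the loudness-normalized input with the de-limiter output
through a "Parallel Mix" coefficient.
"""
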
import os
import json
import argparse

import numpy as np
import matplotlib.pyplot as plt
import torch
import tqdm
import librosa
import librosa.display
import soundfile as sf
import pyloudnorm as pyln
from dotmap import DotMap
import gradio as gr

from models import load_model_with_args
from separate_func import (
    conv_tasnet_separate,
)
from utils import db2linear

# Disable tqdm's monitor thread. The knob lives on the tqdm class, not the
# module, so set it there.
tqdm.tqdm.monitor_interval = 0
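

# Run the model on a full track without tracking gradients. Only the two
# Conv-TasNet architecture variants used by this demo are dispatched here.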
def separate_track_with_model(
    args, model, device, track_audio, track_name, meter, augmented_gain
):
    with torch.no_grad():
        if args.model_loss_params.architecture in (
            "conv_tasnet_mask_on_output",
            "conv_tasnet",
        ):
            estimates = conv_tasnet_separate(
                args,
                model,
                device,
                track_audio,
                track_name,
                meter=meter,
                augmented_gain=augmented_gain,
            )

            return estimates
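

# Gradio callback for the "De-limit" button. `input` is a (sample_rate, ndarray)
# tuple from gr.Audio; `mix_coefficient` comes from the parallel-mix slider.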
def main(input, mix_coefficient):
    parser = argparse.ArgumentParser(description="model test.py")
    parser.add_argument("--target", type=str, default="all")
    parser.add_argument("--weight_directory", type=str, default="weight")
    parser.add_argument("--output_directory", type=str, default="output")
    # NOTE: argparse's type=bool does not parse "False" as False; these boolean
    # flags are only ever used with their defaults in this demo.
    parser.add_argument("--use_gpu", type=bool, default=True)
    parser.add_argument("--save_name_as_target", type=bool, default=False)
    parser.add_argument(
        "--loudnorm_input_lufs",
        type=float,
        default=None,
        help="Target loudness (LUFS) for loudness-normalizing the input, if desired.",
    )
    parser.add_argument(
        "--save_output_loudnorm",
        type=float,
        default=-14.0,
        help="Target loudness (LUFS) for saved outputs. Set a value to save loudness-normalized outputs.",
    )
    parser.add_argument(
        "--save_mixed_output",
        type=float,
        default=None,
        help="Save a mix of the original and the de-limited estimation: ratio * original + (1 - ratio) * estimation.",
    )
    parser.add_argument(
        "--save_16k_mono",
        type=bool,
        default=False,
        help="Save 16 kHz mono wav files for FAD evaluation.",
    )
    parser.add_argument(
        "--save_histogram",
        type=bool,
        default=False,
        help="Save a histogram of the output. Only valid when the task is 'delimit'.",
    )
    parser.add_argument(
        "--use_singletrackset",
        type=bool,
        default=False,
        help="Use SingleTrackSet if the input data is too long.",
    )

    args, _ = parser.parse_known_args()
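
    # Merge the saved training configuration ({weight_directory}/{target}.json)
    # into the CLI args without overriding anything parsed above.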
    with open(f"{args.weight_directory}/{args.target}.json", "r") as f:
        args_dict = json.load(f)
        args_dict = DotMap(args_dict)

    for key, value in args_dict["args"].items():
        if key not in vars(args):
            setattr(args, key, value)

    args.test_output_dir = f"{args.output_directory}"
    os.makedirs(args.test_output_dir, exist_ok=True)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
    )
    ###################### Define Models ######################
    our_model = load_model_with_args(args)
    our_model = our_model.to(device)

    target_model_path = f"{args.weight_directory}/{args.target}.pth"
    checkpoint = torch.load(target_model_path, map_location=device)
    our_model.load_state_dict(checkpoint)
    our_model.eval()

    meter = pyln.Meter(44100)  # ITU-R BS.1770 loudness meter at 44.1 kHz

    sr, track_audio = input
    track_audio = track_audio.T  # (samples, channels) -> (channels, samples)
    track_name = "gradio_demo"
    orig_audio = track_audio.copy()

    if sr != 44100:
        raise ValueError("Sample rate should be 44100")
    augmented_gain = None
    if args.loudnorm_input_lufs is not None:  # loudness-normalize the input first
        track_lufs = meter.integrated_loudness(track_audio.T)
        augmented_gain = args.loudnorm_input_lufs - track_lufs
        track_audio = track_audio * db2linear(augmented_gain, eps=0.0)

    track_audio = (
        torch.as_tensor(track_audio, dtype=torch.float32).unsqueeze(0).to(device)
    )
    estimates = separate_track_with_model(
        args, our_model, device, track_audio, track_name, meter, augmented_gain
    )

    if args.save_mixed_output:
        track_lufs = meter.integrated_loudness(orig_audio.T)
        augmented_gain = args.save_output_loudnorm - track_lufs
        orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)

        mixed_output = orig_audio * args.save_mixed_output + estimates * (
            1 - args.save_mixed_output
        )

        # Make sure the per-track output directory exists before writing.
        os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
        sf.write(
            f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
            mixed_output.T,
            args.data_params.sample_rate,
        )

    return (
        (sr, estimates.T),
        (sr, orig_audio.T),
        (sr, orig_audio.T * mix_coefficient + estimates.T * (1 - mix_coefficient)),
    )
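

# Slider callback: linear ("parallel") mix of the loudness-normalized input and
# the de-limiter output, i.e. input * c + output * (1 - c).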
def parallel_mix(input, output, mix_coefficient):
    sr = 44100
    return sr, input[1] * mix_coefficient + output[1] * (1 - mix_coefficient)
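

# Convert 16-bit PCM samples to float in [-1.0, 1.0) by scaling with 1 / 2**15.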
def int16_to_float32(wav):
    wav = np.frombuffer(wav, dtype=np.int16)
    return wav / 32768
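

# Draw stacked waveform plots of the input, the de-limiter output, and their
# parallel mix on shared axes for the "Show Plots" button.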
def waveform_plot(input, output, prl_mix_output, figsize_x=20, figsize_y=9):
    sr = 44100
    fig, ax = plt.subplots(
        nrows=3, sharex=True, sharey=True, figsize=(figsize_x, figsize_y)
    )
    librosa.display.waveshow(int16_to_float32(input[1]).T, sr=sr, ax=ax[0])
    ax[0].set(title="Loudness Normalized Input")
    ax[0].label_outer()
    librosa.display.waveshow(int16_to_float32(output[1]).T, sr=sr, ax=ax[1])
    ax[1].set(title="De-limiter Output")
    ax[1].label_outer()
    librosa.display.waveshow(int16_to_float32(prl_mix_output[1]).T, sr=sr, ax=ax[2])
    ax[2].set(title="Parallel Mix of the Input and its De-limiter Output")
    ax[2].label_outer()
    return fig
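

###################### Gradio UI ######################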
with gr.Blocks() as demo:
    gr.HTML(
        """
            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
              <div
                style="
                  display: inline-flex;
                  align-items: center;
                  gap: 0.8rem;
                  font-size: 1.75rem;
                "
              >
                <h1 style="font-weight: 900; margin-bottom: 7px;">
                  Music De-limiter
                </h1>
              </div>
              <p style="margin-bottom: 10px; font-size: 94%">
                A demo for "Music De-limiter via Sample-wise Gain Inversion", to appear at WASPAA 2023.
                First upload a music file (.wav or .mp3), then press the "De-limit" button to apply the de-limiter. Since this runs on a CPU rather than a GPU, it may take a few minutes.
                You can then apply the Parallel Mix technique, a simple linear blend of the "loudness normalized input" and the "de-limiter output",
                and adjust the mixing coefficient yourself.
                With a coefficient of 0.3, the output is loudness_normalized_input * 0.3 + de-limiter_output * 0.7.
              </p>
            </div>
        """
    )
    with gr.Row().style(mobile_collapse=False, equal_height=True):
        with gr.Column():
            with gr.Box():
                input_audio = gr.Audio(source="upload", label="De-limiter Input")
                btn = gr.Button("De-limit")
        with gr.Column():
            with gr.Box():
                loud_norm_input = gr.Audio(label="Loudness Normalized Input (-14 LUFS)")
            with gr.Box():
                output_audio = gr.Audio(label="De-limiter Output")
            with gr.Box():
                output_audio_parallel = gr.Audio(
                    label="Parallel Mix of the Input and its De-limiter Output"
                )
                slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.5,
                    label="Parallel Mix Coefficient",
                )
    btn.click(
        main,
        inputs=[input_audio, slider],
        outputs=[output_audio, loud_norm_input, output_audio_parallel],
    )
    slider.release(
        parallel_mix,
        inputs=[input_audio, output_audio, slider],
        outputs=output_audio_parallel,
    )
    with gr.Row().style(mobile_collapse=False, equal_height=True):
        with gr.Column():
            with gr.Box():
                plot = gr.Plot(label="Plots")
                btn2 = gr.Button("Show Plots")
                slider_plot_x = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=20,
                    label="Plot X-axis size",
                )
                slider_plot_y = gr.Slider(
                    minimum=1,
                    maximum=30,
                    step=1,
                    value=9,
                    label="Plot Y-axis size",
                )
    btn2.click(
        waveform_plot,
        inputs=[
            loud_norm_input,
            output_audio,
            output_audio_parallel,
            slider_plot_x,
            slider_plot_y,
        ],
        outputs=plot,
    )
    slider.release(
        waveform_plot,
        inputs=[
            loud_norm_input,
            output_audio,
            output_audio_parallel,
            slider_plot_x,
            slider_plot_y,
        ],
        outputs=plot,
    )
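
# A minimal offline usage sketch (not part of the demo; "song.wav" is a
# hypothetical path), assuming a 44.1 kHz int16 stereo file:
#
#     audio, sr = sf.read("song.wav", dtype="int16")  # audio: (samples, channels)
#     (sr, est), (sr, norm_in), (sr, mix) = main((sr, audio), mix_coefficient=0.5)
#     sf.write("delimited.wav", est, sr)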
if __name__ == "__main__":
    demo.launch(debug=True)