import os
import json
import argparse
import glob

import torch
import tqdm
import librosa
import soundfile as sf
import pyloudnorm as pyln
from dotmap import DotMap

from models import load_model_with_args
from separate_func import (
    conv_tasnet_separate,
)
from utils import str2bool, db2linear
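
# Helpers imported from utils (a sketch of assumed behavior, not a
# documented API): str2bool(s) parses "true"/"false"-style CLI strings
# into a bool, and db2linear(db, eps) presumably converts a dB gain into
# a linear amplitude factor, roughly 10 ** (db / 20) plus the eps term.
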
tqdm.tqdm.monitor_interval = 0  # disable tqdm's background monitor thread (a class attribute, so set it on the tqdm class)


def separate_track_with_model(
    args, model, device, track_audio, track_name, meter, augmented_gain
):
    """Run the loaded model on one track and return the de-limited estimate."""
    with torch.no_grad():
        if args.model_loss_params.architecture in (
            "conv_tasnet_mask_on_output",
            "conv_tasnet",
        ):
            estimates = conv_tasnet_separate(
                args,
                model,
                device,
                track_audio,
                track_name,
                meter=meter,
                augmented_gain=augmented_gain,
            )
            return estimates
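
# Note: conv_tasnet_separate (from separate_func) appears to handle saving
# the de-limited estimate itself; only the optional mixed output is written
# in main() below. Judging from how `estimates` is blended with the raw
# numpy input there, it is assumed to return a numpy array with the same
# (channels, samples) layout as the input track.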


def main():
    parser = argparse.ArgumentParser(description="model test.py")
    parser.add_argument("--target", type=str, default="all")
    parser.add_argument("--data_root", type=str, default="./input_data")
    parser.add_argument("--weight_directory", type=str, default="./weight")
    parser.add_argument("--output_directory", type=str, default="./output")
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--save_name_as_target", type=str2bool, default=False)
    parser.add_argument(
        "--loudnorm_input_lufs",
        type=float,
        default=None,
        help="Target LUFS for loudness-normalizing the input before separation. Disabled when not set.",
    )
    parser.add_argument(
        "--save_output_loudnorm",
        type=float,
        default=-14.0,
        help="Target loudness (LUFS) for loudness-normalizing the saved outputs.",
    )
    parser.add_argument(
        "--save_mixed_output",
        type=float,
        default=None,
        help="Mixing ratio for saving an original/de-limited blend: ratio * original + (1 - ratio) * estimation, e.g. 0.5. Disabled when not set.",
    )
    parser.add_argument(
        "--save_16k_mono",
        type=str2bool,
        default=False,
        help="Save 16 kHz mono wav files for FAD evaluation.",
    )
    parser.add_argument(
        "--save_histogram",
        type=str2bool,
        default=False,
        help="Save a histogram of the output. Only valid when the task is 'delimit'.",
    )
    parser.add_argument(
        "--use_singletrackset",
        type=str2bool,
        default=False,
        help="Use SingleTrackSet if the input data is too long.",
    )

    args, _ = parser.parse_known_args()
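
    # The JSON saved alongside the weights is assumed to look roughly like
    # {"args": {...training-time hyperparameters...}} (inferred from the
    # access pattern below, not from a documented format). Any key the CLI
    # parser above already owns keeps its command-line value.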
    with open(f"{args.weight_directory}/{args.target}.json", "r") as f:
        args_dict = json.load(f)
    args_dict = DotMap(args_dict)

    for key, value in args_dict["args"].items():
        if key not in vars(args):
            setattr(args, key, value)

    args.test_output_dir = args.output_directory
    os.makedirs(args.test_output_dir, exist_ok=True)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
    )

    ###################### Define Models ######################
    our_model = load_model_with_args(args)
    our_model = our_model.to(device)

    target_model_path = f"{args.weight_directory}/{args.target}.pth"
    checkpoint = torch.load(target_model_path, map_location=device)
    our_model.load_state_dict(checkpoint)
    our_model.eval()

    meter = pyln.Meter(44100)
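    # pyloudnorm's Meter measures integrated loudness per ITU-R BS.1770;
    # it expects audio shaped (samples, channels), hence the .T transposes
    # on the (channels, samples) arrays that librosa returns below.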

    test_tracks = glob.glob(f"{args.data_root}/*.wav") + glob.glob(
        f"{args.data_root}/*.mp3"
    )

    for track in tqdm.tqdm(test_tracks):
        track_name = os.path.splitext(os.path.basename(track))[0]
        track_audio, sr = librosa.load(track, sr=None, mono=False)  # sr should be 44100
        orig_audio = track_audio.copy()
        if sr != 44100:
            raise ValueError("Sample rate should be 44100")

        augmented_gain = None
        print("Now de-limiting:", track_name)

        if args.loudnorm_input_lufs is not None:  # loudness-normalize the input
            # gain (dB) needed to move the measured loudness to the target LUFS
            track_lufs = meter.integrated_loudness(track_audio.T)
            augmented_gain = args.loudnorm_input_lufs - track_lufs
            track_audio = track_audio * db2linear(augmented_gain, eps=0.0)

        track_audio = (
            torch.as_tensor(track_audio, dtype=torch.float32).unsqueeze(0).to(device)
        )

        estimates = separate_track_with_model(
            args, our_model, device, track_audio, track_name, meter, augmented_gain
        )

        if args.save_mixed_output is not None:
            # Loudness-normalize the original to the output target, then
            # blend: ratio * original + (1 - ratio) * estimation.
            track_lufs = meter.integrated_loudness(orig_audio.T)
            augmented_gain = args.save_output_loudnorm - track_lufs
            orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)

            mixed_output = orig_audio * args.save_mixed_output + estimates * (
                1 - args.save_mixed_output
            )

            # ensure the per-track output directory exists before writing
            os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
            sf.write(
                f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
                mixed_output.T,
                args.data_params.sample_rate,
            )


if __name__ == "__main__":
    main()
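
# Example invocation (a sketch based on the defaults above; the script
# name, the "all" target, and the paths are assumptions, not a documented
# interface). It expects ./weight/all.json and ./weight/all.pth to exist
# and writes results under ./output:
#
#   python inference.py \
#       --target all \
#       --data_root ./input_data \
#       --weight_directory ./weight \
#       --output_directory ./output \
#       --save_output_loudnorm -14.0 \
#       --save_mixed_output 0.5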