# To be honest... this is not ddp.
import os | |
import json | |
import argparse | |
import glob | |
import torch | |
import tqdm | |
import musdb | |
import librosa | |
import soundfile as sf | |
import pyloudnorm as pyln | |
from dotmap import DotMap | |
from models import load_model_with_args | |
from separate_func import ( | |
conv_tasnet_separate, | |
) | |
from utils import str2bool, db2linear | |
# Disable tqdm's background monitor thread (interval 0 turns it off) —
# presumably to avoid stray-thread issues during long inference runs; TODO confirm.
tqdm.monitor_interval = 0
def separate_track_with_model(
    args, model, device, track_audio, track_name, meter, augmented_gain
):
    """Run the de-limiter model on one track and return its estimates.

    Args:
        args: experiment configuration; reads ``args.model_loss_params.architecture``.
        model: the loaded separation model.
        device: torch device the audio tensor lives on.
        track_audio: input mixture tensor — assumed (1, channels, samples); TODO confirm.
        track_name: track identifier passed through to the separation routine.
        meter: pyloudnorm ``Meter`` forwarded to the separation routine.
        augmented_gain: gain in dB that was applied to the input, or ``None``.

    Returns:
        The estimates produced by the architecture-specific separation routine.

    Raises:
        ValueError: if the configured architecture is not supported here.
    """
    with torch.no_grad():
        architecture = args.model_loss_params.architecture
        if architecture in ("conv_tasnet_mask_on_output", "conv_tasnet"):
            return conv_tasnet_separate(
                args,
                model,
                device,
                track_audio,
                track_name,
                meter=meter,
                augmented_gain=augmented_gain,
            )
    # The original fell through to `return estimates` with `estimates` unbound,
    # crashing with UnboundLocalError. Fail with an explicit, actionable error.
    raise ValueError(f"Unsupported architecture: {architecture}")
def main():
    """Load a trained de-limiter checkpoint and run inference on test tracks.

    Restores the experiment configuration saved at training time, loads the
    model weights, then separates either the musdb test set (``--use_musdb``)
    or loose wav/mp3 files under ``--data_root``.  Optionally loudness-
    normalizes the input and/or saves an original+estimate mix.
    """
    parser = argparse.ArgumentParser(description="model test.py")
    parser.add_argument("--target", type=str, default="all")
    parser.add_argument("--data_root", type=str, default="/path/to/musdb_XL")
    parser.add_argument(
        "--use_musdb",
        type=str2bool,
        default=True,
        help="Use musdb test data or just want to inference other samples?",
    )
    # NOTE(fix): default was written as "delimit_6_s') — mismatched quotes,
    # a syntax error that made the whole module unimportable.
    parser.add_argument("--exp_name", type=str, default="delimit_6_s")
    parser.add_argument("--manual_output_name", type=str, default=None)
    parser.add_argument(
        "--output_directory", type=str, default="/path/to/results"
    )
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    # NOTE(fix): was parser.add_arugment (typo) — would raise AttributeError.
    parser.add_argument("--save_name_as_target", type=str2bool, default=True)
    parser.add_argument(
        "--loudnorm_input_lufs",
        type=float,
        default=None,
        help="If you want to use loudnorm, input target lufs",
    )
    parser.add_argument(
        "--use_singletrackset",
        type=str2bool,
        default=False,
        help="Use SingleTrackSet for X-UMX",
    )
    parser.add_argument(
        "--best_model",
        type=str2bool,
        default=True,
        help="Use best model or lastly saved model",
    )
    parser.add_argument(
        "--save_output_loudnorm",
        type=float,
        default=None,
        help="Save loudness normalized outputs or not. If you want to save, input target loudness",
    )
    parser.add_argument(
        "--save_mixed_output",
        type=float,
        default=None,
        help="Save original+delimited-estimation mixed output with a ratio of default 0.5 (orginal) and 1 - 0.5 (estimation)",
    )
    parser.add_argument(
        "--save_16k_mono",
        type=str2bool,
        default=False,
        help="Save 16k mono wav files for FAD evaluation.",
    )
    parser.add_argument(
        "--save_histogram",
        type=str2bool,
        default=False,
        help="Save histogram of the output. Only valid when the task is 'delimit'",
    )

    args, _ = parser.parse_known_args()

    args.output_dir = f"{args.output_directory}/checkpoint/{args.exp_name}"
    with open(f"{args.output_dir}/{args.target}.json", "r") as f:
        args_dict = json.load(f)
    args_dict = DotMap(args_dict)

    # Merge the args saved at training time into the CLI namespace without
    # overriding anything the user supplied on the command line.
    for key, value in args_dict["args"].items():
        if key not in vars(args):
            setattr(args, key, value)

    args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
    if args.manual_output_name is not None:
        args.test_output_dir = f"{args.output_directory}/test/{args.manual_output_name}"
    os.makedirs(args.test_output_dir, exist_ok=True)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
    )

    ###################### Define Models ######################
    our_model = load_model_with_args(args)
    our_model = our_model.to(device)
    print(our_model)
    pytorch_total_params = sum(
        p.numel() for p in our_model.parameters() if p.requires_grad
    )
    print("Total number of parameters", pytorch_total_params)
    # Future work => Torchinfo would be better for this purpose.

    if args.best_model:
        # Best checkpoint is a bare state_dict (.pth).
        target_model_path = f"{args.output_dir}/{args.target}.pth"
        checkpoint = torch.load(target_model_path, map_location=device)
        our_model.load_state_dict(checkpoint)
    else:  # when using lastly saved model — full checkpoint dict (.chkpnt)
        target_model_path = f"{args.output_dir}/{args.target}.chkpnt"
        checkpoint = torch.load(target_model_path, map_location=device)
        our_model.load_state_dict(checkpoint["state_dict"])
    our_model.eval()

    # pyloudnorm meter at the fixed 44.1 kHz project sample rate.
    meter = pyln.Meter(44100)

    if args.use_musdb:
        test_tracks = musdb.DB(root=args.data_root, subsets="test", is_wav=True)

        for track in tqdm.tqdm(test_tracks):
            track_name = track.name
            track_audio = track.audio  # (samples, channels) — musdb convention
            orig_audio = track_audio.copy()

            augmented_gain = None
            print("Now De-limiting : ", track_name)

            if args.loudnorm_input_lufs:  # If you want to use loud-normalized input
                track_lufs = meter.integrated_loudness(track_audio)
                augmented_gain = args.loudnorm_input_lufs - track_lufs
                track_audio = track_audio * db2linear(augmented_gain, eps=0.0)

            track_audio = (
                torch.as_tensor(track_audio.T, dtype=torch.float32)
                .unsqueeze(0)
                .to(device)
            )

            estimates = separate_track_with_model(
                args, our_model, device, track_audio, track_name, meter, augmented_gain
            )

            if args.save_mixed_output:
                orig_audio = orig_audio.T
                # NOTE(fix): the original unconditionally computed
                # `args.save_output_loudnorm - track_lufs`, which is a
                # TypeError when --save_output_loudnorm is left at None.
                if args.save_output_loudnorm is not None:
                    track_lufs = meter.integrated_loudness(orig_audio.T)
                    augmented_gain = args.save_output_loudnorm - track_lufs
                    orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)

                mixed_output = orig_audio * args.save_mixed_output + estimates * (
                    1 - args.save_mixed_output
                )

                # Ensure the per-track output directory exists before writing.
                os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
                sf.write(
                    f"{args.test_output_dir}/{track_name}/{str(args.save_mixed_output)}_mixed.wav",
                    mixed_output.T,
                    args.data_params.sample_rate,
                )
    else:
        test_tracks = glob.glob(f"{args.data_root}/*.wav") + glob.glob(
            f"{args.data_root}/*.mp3"
        )

        for track in tqdm.tqdm(test_tracks):
            track_name = os.path.basename(track).replace(".wav", "").replace(".mp3", "")
            track_audio, sr = librosa.load(
                track, sr=None, mono=False
            )  # sr should be 44100
            orig_audio = track_audio.copy()

            if sr != 44100:
                raise ValueError("Sample rate should be 44100")
            augmented_gain = None
            print("Now De-limiting : ", track_name)

            if args.loudnorm_input_lufs:  # If you want to use loud-normalized input
                # librosa returns (channels, samples); pyln wants (samples, channels).
                track_lufs = meter.integrated_loudness(track_audio.T)
                augmented_gain = args.loudnorm_input_lufs - track_lufs
                track_audio = track_audio * db2linear(augmented_gain, eps=0.0)

            track_audio = (
                torch.as_tensor(track_audio, dtype=torch.float32)
                .unsqueeze(0)
                .to(device)
            )

            estimates = separate_track_with_model(
                args, our_model, device, track_audio, track_name, meter, augmented_gain
            )

            if args.save_mixed_output:
                # NOTE(fix): same None-guard as the musdb branch — avoid
                # TypeError when --save_output_loudnorm is not given.
                if args.save_output_loudnorm is not None:
                    track_lufs = meter.integrated_loudness(orig_audio.T)
                    augmented_gain = args.save_output_loudnorm - track_lufs
                    orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)

                mixed_output = orig_audio * args.save_mixed_output + estimates * (
                    1 - args.save_mixed_output
                )

                # Ensure the per-track output directory exists before writing.
                os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
                sf.write(
                    f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
                    mixed_output.T,
                    args.data_params.sample_rate,
                )
# Script entry point: only run inference when executed directly, not on import.
if __name__ == "__main__":
    main()