|
|
|
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' |
|
|
|
import argparse |
|
import time |
|
import librosa |
|
from tqdm.auto import tqdm |
|
import sys |
|
import os |
|
import glob |
|
import torch |
|
import soundfile as sf |
|
import torch.nn as nn |
|
from datetime import datetime |
|
import numpy as np |
|
import librosa |
|
|
|
|
|
# Append the directory that contains this script to the module search path so
# the ``utils`` imports below can be resolved from next to this file. The entry
# is appended, so any ``utils`` found earlier on ``sys.path`` still wins.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

from utils import demix, get_model_from_config, normalize_audio, denormalize_audio
from utils import prefer_target_instrument, apply_tta, load_start_checkpoint, load_lora_weights

import warnings
# Silence every warning category for the remainder of the process.
warnings.filterwarnings("ignore")
|
|
|
|
|
def shorten_filename(filename, max_length=30):
    """
    Shorten a filename so that its base name fits into ``max_length`` characters.

    The extension is always preserved and is not counted against the limit.
    Names whose base already fits are returned unchanged. Longer names keep
    the beginning and the end of the base joined by ``"..."``: up to 15
    leading and 10 trailing characters, scaled down when ``max_length`` is
    too small to hold both.

    Args:
        filename (str): The filename to be shortened
        max_length (int): Maximum allowed length for the base name

    Returns:
        str: Shortened filename
    """
    base, ext = os.path.splitext(filename)
    if len(base) <= max_length:
        return filename

    ellipsis = "..."
    budget = max_length - len(ellipsis)
    if budget <= 0:
        # No room for the ellipsis marker at all: plain truncation.
        return base[:max(max_length, 0)] + ext

    # Previously the 15/10 split was applied unconditionally, so a small
    # max_length produced a result longer than the limit (or even longer
    # than the input). Shrink the kept parts to the available budget.
    tail = min(10, budget // 2)
    head = min(15, budget - tail)

    # Slice from an explicit start index: base[-0:] would return the whole base.
    return base[:head] + ellipsis + base[len(base) - tail:] + ext
|
|
|
def get_soundfile_subtype(pcm_type, is_float=False):
    """
    Pick the soundfile subtype that matches the requested PCM type.

    Args:
        pcm_type (str): Requested PCM type ('PCM_16', 'PCM_24', 'FLOAT')
        is_float (bool): When true, 'FLOAT' is returned regardless of ``pcm_type``

    Returns:
        str: Soundfile subtype; unrecognised ``pcm_type`` values fall back to 'FLOAT'
    """
    supported = ('PCM_16', 'PCM_24', 'FLOAT')
    lookup = dict(zip(supported, supported))
    return 'FLOAT' if is_float else lookup.get(pcm_type, 'FLOAT')
|
|
|
def _progress_share(weight, step_duration, total_duration):
    """Return the slice of ``weight`` percent attributed to the stage just finished.

    The slice is ``weight`` scaled by the stage's share of the time spent on
    the current file so far, or the full ``weight`` when no time has elapsed.
    """
    if total_duration > 0:
        return weight * (step_duration / total_duration)
    return weight


def _emit_progress(base, increment, cap):
    """Print ``Progress: X%`` lines from ``base`` upwards in 0.1 steps, clamped at ``cap``."""
    # NOTE(review): one line is printed per 0.1 step with a 1 ms pause between
    # them; confirm this matches what the consumer of stdout expects.
    for step in np.arange(0.1, increment + 0.1, 0.1):
        time.sleep(0.001)
        print(f"Progress: {min(base + step, cap):.1f}%")


def _close_stage(started, total_duration, weight, base, cap):
    """Account for a finished stage, print its progress and return the new total duration."""
    step_duration = time.time() - started
    total_duration += step_duration
    _emit_progress(base, _progress_share(weight, step_duration, total_duration), cap)
    return total_duration


def _separate_track(model, args, config, device, path, instruments, sample_rate):
    """Separate a single audio file and write one output file per stem.

    ``instruments`` must be a list owned by this call: extra stem names
    ('instrumental_phaseremix', 'instrumental') are appended to it.
    """
    mix, sr = librosa.load(path, sr=sample_rate, mono=False)
    mix_orig = mix.copy()

    normalize = 'normalize' in config.inference and config.inference['normalize'] is True
    if normalize:
        mix, norm_params = normalize_audio(mix)

    # Stage 1: separation of the (optionally normalized) mixture.
    started = time.time()
    waveforms_orig = demix(config, model, mix, device, model_type=args.model_type)
    step_duration = time.time() - started
    total_duration = step_duration
    first_share = _progress_share(30.0, step_duration, total_duration)
    print(f"Progress: {min(first_share, 30.0):.1f}%")

    # Stage 2: optional test-time augmentation.
    if args.use_tta:
        started = time.time()
        waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)
        total_duration = _close_stage(started, total_duration, 20.0, 30.0, 50.0)

    # Stage 3: optional phase-remix instrumental.
    if args.demud_phaseremix_inst:
        print(f"Demudding track (phase remix - instrumental): {path}")
        instr = 'vocals' if 'vocals' in instruments else instruments[0]
        has_instrumental = 'instrumental' in instruments or 'Instrumental' in instruments
        instruments.append('instrumental_phaseremix')

        if has_instrumental:
            mix_modified = 2 * waveforms_orig[instr] - mix_orig
        else:
            mix_modified = mix_orig - 2 * waveforms_orig[instr]
        # Snapshot taken before the model sees the array; used by the
        # 'instrumental present' formula below.
        mix_modified_ = mix_modified.copy()

        started = time.time()
        waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
        total_duration = _close_stage(started, total_duration, 10.0, 50.0, 60.0)

        if args.use_tta:
            started = time.time()
            # Both branches now hand the estimates of the *remixed* signal to
            # apply_tta. One branch used to pass waveforms_orig instead, which
            # disagreed with its sibling.
            # NOTE(review): assumes apply_tta refines the estimates it is
            # given - confirm against utils.apply_tta.
            waveforms_modified = apply_tta(config, model, mix_modified, waveforms_modified, device, args.model_type)
            total_duration = _close_stage(started, total_duration, 10.0, 60.0, 70.0)

        if has_instrumental:
            waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
        else:
            waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]

    # Optional instrumental obtained by subtracting the main stem from the mix.
    if args.extract_instrumental:
        instr = 'vocals' if 'vocals' in instruments else instruments[0]
        waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
        if 'instrumental' not in instruments:
            instruments.append('instrumental')

    # Stage 4: write every stem. Format settings do not depend on the stem,
    # so they are resolved once.
    started = time.time()
    is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
    codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
    if codec == 'flac':
        # FLAC has no floating-point subtype, so the float flag must not
        # override the PCM type here (it used to make every write fail).
        subtype = get_soundfile_subtype(args.pcm_type)
    else:
        subtype = get_soundfile_subtype('FLOAT', is_float)

    shortened_filename = shorten_filename(os.path.basename(path))
    for instr in instruments:
        estimates = waveforms_orig[instr]
        if normalize:
            estimates = denormalize_audio(estimates, norm_params)
        output_filename = f"{shortened_filename}_{instr}.{codec}"
        output_path = os.path.join(args.store_dir, output_filename)
        sf.write(output_path, estimates.T, sr, subtype=subtype)
    total_duration = _close_stage(started, total_duration, 20.0, 70.0, 90.0)

    # Stage 5: short pause, then report the final 90-100% range.
    started = time.time()
    time.sleep(0.1)
    _close_stage(started, total_duration, 10.0, 90.0, 100.0)


def run_folder(model, args, config, device, verbose: bool = False):
    """Separate every file matching ``*.*`` in ``args.input_folder``.

    Results are written to ``args.store_dir`` (created if missing), one file
    per stem. A failure on one file is reported on stdout and the remaining
    files are still processed.

    Args:
        model: Separation model; switched to eval mode before processing.
        args: Parsed options. Reads ``input_folder``, ``store_dir``,
            ``model_type``, ``use_tta``, ``demud_phaseremix_inst``,
            ``extract_instrumental``, ``pcm_type`` and, when present,
            ``export_format`` and ``flac_file``.
        config: Model configuration; ``config.audio.sample_rate`` (default
            44100) and ``config.inference['normalize']`` are read here.
        device: Device handed through to ``demix`` / ``apply_tta``.
        verbose (bool): Accepted for compatibility; currently unused.
    """
    start_time = time.time()
    model.eval()

    mixture_paths = sorted(glob.glob(os.path.join(args.input_folder, '*.*')))
    sample_rate = getattr(config.audio, 'sample_rate', 44100)

    print(f"Total files found: {len(mixture_paths)}. Using sample rate: {sample_rate}")

    instruments = prefer_target_instrument(config)[:]
    os.makedirs(args.store_dir, exist_ok=True)

    total_files = len(mixture_paths)
    for current_file, path in enumerate(mixture_paths, start=1):
        print(f"Processing file {current_file}/{total_files}")
        try:
            # Each file gets its own copy of the stem list: names appended
            # while processing one file must not leak into the next, where
            # they would duplicate outputs and flip the 'instrumental' check.
            _separate_track(model, args, config, device, path, list(instruments), sample_rate)
        except Exception as e:
            print(f'Cannot read track: {path}')
            print(f'Error message: {str(e)}')
            continue

    print(f"Elapsed time: {time.time() - start_time:.2f} seconds.")
|
|
|
def proc_folder(args):
    """Parse command-line options, build the model and run inference on a folder.

    Args:
        args: List of command-line tokens, or ``None`` to read them from
            ``sys.argv``.

    The device is chosen in this order: CPU when ``--force_cpu`` is given,
    otherwise CUDA when available, otherwise MPS when available, otherwise
    CPU. When more than one GPU id is supplied (and CPU is not forced) the
    model is wrapped in ``nn.DataParallel``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c',
                        help="Model type (bandit, bs_roformer, mdx23c, etc.)")
    parser.add_argument("--config_path", type=str, help="Path to config file")
    parser.add_argument("--demud_phaseremix_inst", action='store_true', help="demud_phaseremix_inst")
    parser.add_argument("--start_check_point", type=str, default='',
                        help="Initial checkpoint to valid weights")
    parser.add_argument("--input_folder", type=str, help="Folder with mixtures to process")
    parser.add_argument("--audio_path", type=str, help="Path to a single audio file to process")
    parser.add_argument("--store_dir", default="", type=str, help="Path to store results")
    parser.add_argument("--device_ids", nargs='+', type=int, default=0,
                        help='List of GPU IDs')
    parser.add_argument("--extract_instrumental", action='store_true',
                        help="Invert vocals to get instrumental if provided")
    parser.add_argument("--force_cpu", action='store_true',
                        help="Force the use of CPU even if CUDA is available")
    parser.add_argument("--flac_file", action='store_true',
                        help="Output flac file instead of wav")
    parser.add_argument("--export_format", type=str,
                        choices=['wav FLOAT', 'flac PCM_16', 'flac PCM_24'],
                        default='flac PCM_24',
                        help="Export format and PCM type")
    parser.add_argument("--pcm_type", type=str,
                        choices=['PCM_16', 'PCM_24'],
                        default='PCM_24',
                        help="PCM type for FLAC files")
    parser.add_argument("--use_tta", action='store_true',
                        help="Enable test time augmentation")
    parser.add_argument("--lora_checkpoint", type=str, default='',
                        help="Initial checkpoint to LoRA weights")

    # parse_args(None) already falls back to sys.argv, so a single call covers
    # both the "explicit list" and the "command line" case.
    args = parser.parse_args(args)

    print(f"Audio path provided: {args.audio_path}")

    # --device_ids is a list when given on the command line and the plain
    # default (0) otherwise, hence the isinstance checks below.
    multiple_ids = isinstance(args.device_ids, list)

    device = "cpu"
    if args.force_cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        print('CUDA is available, use --force_cpu to disable it.')
        device = f'cuda:{args.device_ids[0]}' if multiple_ids else f'cuda:{args.device_ids}'
    elif torch.backends.mps.is_available():
        device = "mps"

    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(args.model_type, args.config_path)

    if args.start_check_point != '':
        load_start_checkpoint(args, model, type_='inference')

    print("Instruments: {}".format(config.training.instruments))

    if multiple_ids and len(args.device_ids) > 1 and not args.force_cpu:
        model = nn.DataParallel(model, device_ids=args.device_ids)

    model = model.to(device)

    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    run_folder(model, args, config, device, verbose=True)
|
|
|
|
|
# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # None makes proc_folder's argument parser read the options from sys.argv.
    proc_folder(None)