# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'

import argparse
import time
import sys
import os
import glob
import warnings
from datetime import datetime

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from tqdm.auto import tqdm

# Append this file's directory so the utils module resolves correctly,
# even when running under an embedded Python distribution.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

from utils import demix, get_model_from_config, normalize_audio, denormalize_audio
from utils import prefer_target_instrument, apply_tta, load_start_checkpoint, load_lora_weights

warnings.filterwarnings("ignore")


def shorten_filename(filename, max_length=30):
    """
    Shortens a filename to a specified maximum length.

    Args:
        filename (str): The filename to be shortened
        max_length (int): Maximum allowed length for the filename

    Returns:
        str: Shortened filename
    """
    base, ext = os.path.splitext(filename)
    if len(base) <= max_length:
        return filename
    # Keep the first 15 and the last 10 characters of the base name
    shortened = base[:15] + "..." + base[-10:] + ext
    return shortened


def get_soundfile_subtype(pcm_type, is_float=False):
    """
    Determine the appropriate soundfile subtype for the given PCM type.

    Args:
        pcm_type (str): PCM type ('PCM_16', 'PCM_24', 'FLOAT')
        is_float (bool): Whether the float format should be used

    Returns:
        str: Soundfile subtype
    """
    if is_float:
        return 'FLOAT'

    subtype_map = {
        'PCM_16': 'PCM_16',
        'PCM_24': 'PCM_24',
        'FLOAT': 'FLOAT'
    }
    return subtype_map.get(pcm_type, 'FLOAT')


def run_folder(model, args, config, device, verbose: bool = False):
    start_time = time.time()
    model.eval()

    mixture_paths = sorted(glob.glob(os.path.join(args.input_folder, '*.*')))
    sample_rate = getattr(config.audio, 'sample_rate', 44100)

    print(f"Total files found: {len(mixture_paths)}. Using sample rate: {sample_rate}")

    instruments = prefer_target_instrument(config)[:]
    os.makedirs(args.store_dir, exist_ok=True)

    # Progress tracking
    total_files = len(mixture_paths)
    current_file = 0

    for path in mixture_paths:
        try:
            # Start processing this file
            current_file += 1
            print(f"Processing file {current_file}/{total_files}")

            mix, sr = librosa.load(path, sr=sample_rate, mono=False)
            mix_orig = mix.copy()

            if 'normalize' in config.inference:
                if config.inference['normalize'] is True:
                    mix, norm_params = normalize_audio(mix)

            # Track cumulative step durations; progress is reported out of 100%
            total_duration = 0.0
            total_steps = 100.0
            current_progress = 0.0

            # Model inference and initial separation (0% -> 30%)
            start_time_step = time.time()
            waveforms_orig = demix(config, model, mix, device, model_type=args.model_type)
            step_duration = time.time() - start_time_step
            total_duration += step_duration
            current_progress += 30.0 * (step_duration / total_duration) if total_duration > 0 else 30.0
            print(f"Progress: {min(current_progress, 30.0):.1f}%")

            if args.use_tta:
                # Test-time augmentation (30% -> 50%)
                start_time_step = time.time()
                waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)
                step_duration = time.time() - start_time_step
                total_duration += step_duration
                progress_increment = 20.0 * (step_duration / total_duration) if total_duration > 0 else 20.0
                for i in np.arange(0.1, progress_increment + 0.1, 0.1):
                    current_progress = min(30.0 + i, 50.0)
                    time.sleep(0.001)  # Small delay; may be unnecessary for the real workload
                    print(f"Progress: {current_progress:.1f}%")

            if args.demud_phaseremix_inst:
                print(f"Demudding track (phase remix - instrumental): {path}")
                instr = 'vocals' if 'vocals' in instruments else instruments[0]
                instruments.append('instrumental_phaseremix')
                if 'instrumental' not in instruments and 'Instrumental' not in instruments:
                    mix_modified = mix_orig - 2 * waveforms_orig[instr]
                    mix_modified_ = mix_modified.copy()

                    # Second separation pass on the phase-remixed mixture (50% -> 60%)
                    start_time_step = time.time()
                    waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
                    step_duration = time.time() - start_time_step
                    total_duration += step_duration
                    progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
                    for i in np.arange(0.1, progress_increment + 0.1, 0.1):
                        current_progress = min(50.0 + i, 60.0)
                        time.sleep(0.001)
                        print(f"Progress: {current_progress:.1f}%")

                    if args.use_tta:
                        # TTA on the phase-remixed mixture (60% -> 70%)
                        start_time_step = time.time()
                        waveforms_modified = apply_tta(config, model, mix_modified, waveforms_modified, device, args.model_type)
                        step_duration = time.time() - start_time_step
                        total_duration += step_duration
                        progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
                        for i in np.arange(0.1, progress_increment + 0.1, 0.1):
                            current_progress = min(60.0 + i, 70.0)
                            time.sleep(0.001)
                            print(f"Progress: {current_progress:.1f}%")

                    waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]
                else:
                    mix_modified = 2 * waveforms_orig[instr] - mix_orig
                    mix_modified_ = mix_modified.copy()

                    # Second separation pass on the phase-remixed mixture (50% -> 60%)
                    start_time_step = time.time()
                    waveforms_modified = demix(config, model, mix_modified, device, model_type=args.model_type)
                    step_duration = time.time() - start_time_step
                    total_duration += step_duration
                    progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
                    for i in np.arange(0.1, progress_increment + 0.1, 0.1):
                        current_progress = min(50.0 + i, 60.0)
                        time.sleep(0.001)
                        print(f"Progress: {current_progress:.1f}%")

                    if args.use_tta:
                        # TTA on the phase-remixed mixture (60% -> 70%)
                        start_time_step = time.time()
                        waveforms_modified = apply_tta(config, model, mix_modified, waveforms_modified, device, args.model_type)
                        step_duration = time.time() - start_time_step
                        total_duration += step_duration
                        progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
                        for i in np.arange(0.1, progress_increment + 0.1, 0.1):
                            current_progress = min(60.0 + i, 70.0)
                            time.sleep(0.001)
                            print(f"Progress: {current_progress:.1f}%")

                    waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]

            current_progress = 70.0

            if args.extract_instrumental:
                instr = 'vocals' if 'vocals' in instruments else instruments[0]
                waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
                if 'instrumental' not in instruments:
                    instruments.append('instrumental')

            # Write output files and finalize (70% -> 100%)
            start_time_step = time.time()
            for instr in instruments:
                estimates = waveforms_orig[instr]
                if 'normalize' in config.inference:
                    if config.inference['normalize'] is True:
                        estimates = denormalize_audio(estimates, norm_params)

                # Determine output format and PCM type
                is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
                codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'

                # Pick the soundfile subtype
                if codec == 'flac':
                    subtype = get_soundfile_subtype(args.pcm_type, is_float)
                else:
                    subtype = get_soundfile_subtype('FLOAT', is_float)

                shortened_filename = shorten_filename(os.path.basename(path))
                output_filename = f"{shortened_filename}_{instr}.{codec}"
                output_path = os.path.join(args.store_dir, output_filename)
                sf.write(output_path, estimates.T, sr, subtype=subtype)

            step_duration = time.time() - start_time_step
            total_duration += step_duration
            progress_increment = 20.0 * (step_duration / total_duration) if total_duration > 0 else 20.0
            for i in np.arange(0.1, progress_increment + 0.1, 0.1):
                current_progress = min(70.0 + i, 90.0)
                time.sleep(0.001)
                print(f"Progress: {current_progress:.1f}%")

            # Finalize (90% -> 100%)
            start_time_step = time.time()
            time.sleep(0.1)  # Brief wait for finalization (can be replaced with real work)
            step_duration = time.time() - start_time_step
            total_duration += step_duration
            progress_increment = 10.0 * (step_duration / total_duration) if total_duration > 0 else 10.0
            for i in np.arange(0.1, progress_increment + 0.1, 0.1):
                current_progress = min(90.0 + i, 100.0)
                time.sleep(0.001)
                print(f"Progress: {current_progress:.1f}%")

        except Exception as e:
            print(f'Cannot process track: {path}')
            print(f'Error message: {str(e)}')
            continue

    print(f"Elapsed time: {time.time() - start_time:.2f} seconds.")


def proc_folder(args):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, default='mdx23c',
                        help="Model type (bandit, bs_roformer, mdx23c, etc.)")
    parser.add_argument("--config_path", type=str, help="Path to config file")
    parser.add_argument("--demud_phaseremix_inst", action='store_true',
                        help="Enable the demud phase-remix instrumental output")
    parser.add_argument("--start_check_point", type=str, default='',
                        help="Initial checkpoint with model weights")
    parser.add_argument("--input_folder", type=str, help="Folder with mixtures to process")
    parser.add_argument("--audio_path", type=str, help="Path to a single audio file to process")
    parser.add_argument("--store_dir", default="", type=str, help="Path to store results")
    parser.add_argument("--device_ids", nargs='+', type=int, default=0, help='List of GPU IDs')
    parser.add_argument("--extract_instrumental", action='store_true',
                        help="Invert vocals to get instrumental if provided")
parser.add_argument("--force_cpu", action='store_true', help="Force the use of CPU even if CUDA is available") parser.add_argument("--flac_file", action='store_true', help="Output flac file instead of wav") parser.add_argument("--export_format", type=str, choices=['wav FLOAT', 'flac PCM_16', 'flac PCM_24'], default='flac PCM_24', help="Export format and PCM type") parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24', help="PCM type for FLAC files") parser.add_argument("--use_tta", action='store_true', help="Enable test time augmentation") parser.add_argument("--lora_checkpoint", type=str, default='', help="Initial checkpoint to LoRA weights") # Argümanları ayrıştır parsed_args = parser.parse_args(args) # Burada parsed_args.audio_path ile ses dosyası yolunu kullanabilirsiniz print(f"Audio path provided: {parsed_args.audio_path}") if args is None: args = parser.parse_args() else: args = parser.parse_args(args) # Cihaz seçimi device = "cpu" if args.force_cpu: device = "cpu" elif torch.cuda.is_available(): print('CUDA is available, use --force_cpu to disable it.') device = f'cuda:{args.device_ids[0]}' if type(args.device_ids) == list else f'cuda:{args.device_ids}' elif torch.backends.mps.is_available(): device = "mps" print("Using device: ", device) model_load_start_time = time.time() torch.backends.cudnn.benchmark = True model, config = get_model_from_config(args.model_type, args.config_path) if args.start_check_point != '': load_start_checkpoint(args, model, type_='inference') print("Instruments: {}".format(config.training.instruments)) # Çoklu CUDA GPU kullanımı if type(args.device_ids) == list and len(args.device_ids) > 1 and not args.force_cpu: model = nn.DataParallel(model, device_ids=args.device_ids) model = model.to(device) print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time)) run_folder(model, args, config, device, verbose=True) if __name__ == "__main__": proc_folder(None)