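# Dataset-preparation script: downloads audio from one or more URLs with yt_dlp,
# optionally separates vocals (and removes reverb) with MDX ONNX models, trims the
# start/end of each file, splits on silence and re-merges the non-silent parts, then
# optionally resamples and denoises before moving the results into the dataset folder.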
import os
import re
import sys
import time
import yt_dlp
import shutil
import librosa
import logging
import argparse
import warnings
import logging.handlers

import soundfile as sf
import noisereduce as nr

from distutils.util import strtobool
from pydub import AudioSegment, silence

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
from main.library.algorithm.separator import Separator

translations = Config().translations
log_file = os.path.join("assets", "logs", "create_dataset.log")
logger = logging.getLogger(__name__)
# Logging: clear any pre-existing handlers so they are not duplicated, then attach a
# console handler (INFO) and a rotating file handler (DEBUG). file_formatter is also
# reused later by separator_main(), so it must always be defined at module level.
if logger.hasHandlers(): logger.handlers.clear()

console_handler = logging.StreamHandler()
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
console_handler.setFormatter(console_formatter)
console_handler.setLevel(logging.INFO)

file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
file_handler.setFormatter(file_formatter)
file_handler.setLevel(logging.DEBUG)

logger.addHandler(console_handler)
logger.addHandler(file_handler)
logger.setLevel(logging.DEBUG)
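# Command-line interface for the pipeline below. Illustrative invocation only
# (the script name and flag values are assumptions; adjust them to your checkout):
#   python create_dataset.py --input_audio "https://youtu.be/..." --output_dataset ./dataset \
#       --separator_music True --kim_vocal_version 2 --clean_dataset True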
def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_audio", type=str, required=True)
    parser.add_argument("--output_dataset", type=str, default="./dataset")
    parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--resample_sr", type=int, default=44100)
    parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clean_strength", type=float, default=0.7)
    parser.add_argument("--separator_music", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--separator_reverb", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--kim_vocal_version", type=int, default=2)
    parser.add_argument("--overlap", type=float, default=0.25)
    parser.add_argument("--segments_size", type=int, default=256)
    parser.add_argument("--mdx_hop_length", type=int, default=1024)
    parser.add_argument("--mdx_batch_size", type=int, default=1)
    parser.add_argument("--denoise_mdx", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--skip", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--skip_start_audios", type=str, default="0")
    parser.add_argument("--skip_end_audios", type=str, default="0")

    args = parser.parse_args()
    return args
dataset_temp = os.path.join("dataset_temp")
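# Main pipeline: download -> optional trim -> optional vocal/reverb separation ->
# silence-based split and merge -> optional resample/noise reduction -> move to output.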
def main():
    args = parse_arguments()
    input_audio = args.input_audio
    output_dataset = args.output_dataset
    resample = args.resample
    resample_sr = args.resample_sr
    clean_dataset = args.clean_dataset
    clean_strength = args.clean_strength
    separator_music = args.separator_music
    separator_reverb = args.separator_reverb
    kim_vocal_version = args.kim_vocal_version
    overlap = args.overlap
    segments_size = args.segments_size
    hop_length = args.mdx_hop_length
    batch_size = args.mdx_batch_size
    denoise_mdx = args.denoise_mdx
    skip = args.skip
    skip_start_audios = args.skip_start_audios
    skip_end_audios = args.skip_end_audios

    logger.debug(f"{translations['audio_path']}: {input_audio}")
    logger.debug(f"{translations['output_path']}: {output_dataset}")
    logger.debug(f"{translations['resample']}: {resample}")
    if resample: logger.debug(f"{translations['sample_rate']}: {resample_sr}")
    logger.debug(f"{translations['clear_dataset']}: {clean_dataset}")
    if clean_dataset: logger.debug(f"{translations['clean_strength']}: {clean_strength}")
    logger.debug(f"{translations['separator_music']}: {separator_music}")
    logger.debug(f"{translations['dereveb_audio']}: {separator_reverb}")
    if separator_music: logger.debug(f"{translations['training_version']}: {kim_vocal_version}")
    logger.debug(f"{translations['segments_size']}: {segments_size}")
    logger.debug(f"{translations['overlap']}: {overlap}")
    logger.debug(f"Hop length: {hop_length}")
    logger.debug(f"{translations['batch_size']}: {batch_size}")
    logger.debug(f"{translations['denoise_mdx']}: {denoise_mdx}")
    logger.debug(f"{translations['skip']}: {skip}")
    if skip: logger.debug(f"{translations['skip_start']}: {skip_start_audios}")
    if skip: logger.debug(f"{translations['skip_end']}: {skip_end_audios}")

    if kim_vocal_version not in (1, 2): raise ValueError(translations["version_not_valid"])
    if separator_reverb and not separator_music: raise ValueError(translations["create_dataset_value_not_valid"])

    start_time = time.time()
    try:
        paths = []
        if not os.path.exists(dataset_temp): os.makedirs(dataset_temp, exist_ok=True)
        if not os.path.exists(output_dataset): os.makedirs(output_dataset, exist_ok=True)

        urls = input_audio.replace(", ", ",").split(",")
        for i, url in enumerate(urls):
            path = downloader(url, i)
            paths.append(path)

        if skip:
            # Skip values arrive as comma-separated strings of seconds, one per audio file.
            skip_start_audios = skip_start_audios.replace(", ", ",").split(",")
            skip_end_audios = skip_end_audios.replace(", ", ",").split(",")

            if len(skip_start_audios) < len(paths) or len(skip_end_audios) < len(paths):
                logger.warning(translations["skip<audio"])
                sys.exit(1)
            elif len(skip_start_audios) > len(paths) or len(skip_end_audios) > len(paths):
                logger.warning(translations["skip>audio"])
                sys.exit(1)
            else:
                for audio, skip_start_audio, skip_end_audio in zip(paths, skip_start_audios, skip_end_audios):
                    skip_start(audio, float(skip_start_audio))
                    skip_end(audio, float(skip_end_audio))
        if separator_music:
            separator_paths = []

            for audio in paths:
                vocals = separator_music_main(audio, dataset_temp, segments_size, overlap, denoise_mdx, kim_vocal_version, hop_length, batch_size)
                if separator_reverb: vocals = separator_reverb_audio(vocals, dataset_temp, segments_size, overlap, denoise_mdx, hop_length, batch_size)
                separator_paths.append(vocals)

            paths = separator_paths

        processed_paths = []

        for audio in paths:
            output = process_audio(audio)
            processed_paths.append(output)

        paths = processed_paths

        for audio_path in paths:
            data, sample_rate = sf.read(audio_path)

            if resample_sr != sample_rate and resample_sr > 0 and resample:
                data = librosa.resample(data, orig_sr=sample_rate, target_sr=resample_sr)
                sample_rate = resample_sr

            if clean_dataset: data = nr.reduce_noise(y=data, prop_decrease=clean_strength)

            sf.write(audio_path, data, sample_rate)
    except Exception as e:
        raise RuntimeError(f"{translations['create_dataset_error']}: {e}")
    finally:
        for audio in paths:
            shutil.move(audio, output_dataset)

        if os.path.exists(dataset_temp): shutil.rmtree(dataset_temp, ignore_errors=True)

    elapsed_time = time.time() - start_time
    logger.info(translations["create_dataset_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
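# Download a single URL into dataset_temp as WAV via yt_dlp's FFmpegExtractAudio
# post-processor; the numeric index is used as the output file name.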
def downloader(url, name):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ydl_opts = {
            'format': 'bestaudio/best',
            'outtmpl': os.path.join(dataset_temp, f"{name}"),
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'wav',
                'preferredquality': '192',
            }],
            'noplaylist': True,
            'verbose': False,
        }

        logger.info(f"{translations['starting_download']}: {url}...")

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.extract_info(url)
            logger.info(f"{translations['download_success']}: {url}")

    return os.path.join(dataset_temp, f"{name}" + ".wav")
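# Trim helpers: drop the first / last `seconds` of an audio file in place.
# Out-of-range values are only logged and the file is left unchanged.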
def skip_start(input_file, seconds):
    data, sr = sf.read(input_file)
    total_duration = len(data) / sr

    if seconds <= 0: logger.warning(translations["=<0"])
    elif seconds >= total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
    else:
        logger.info(f"{translations['skip_start']}: {input_file}...")
        sf.write(input_file, data[int(seconds * sr):], sr)
        logger.info(translations["skip_start_audio"].format(input_file=input_file))

def skip_end(input_file, seconds):
    data, sr = sf.read(input_file)
    total_duration = len(data) / sr

    if seconds <= 0: logger.warning(translations["=<0"])
    elif seconds > total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
    else:
        logger.info(f"{translations['skip_end']}: {input_file}...")
        sf.write(input_file, data[:-int(seconds * sr)], sr)
        logger.info(translations["skip_end_audio"].format(input_file=input_file))
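# Split an audio file on silence (pydub), export each non-silent chunk, then
# concatenate the chunks back into a single "<name>_processed.wav" file.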
def process_audio(file_path):
    try:
        song = AudioSegment.from_file(file_path)
        nonsilent_parts = silence.detect_nonsilent(song, min_silence_len=750, silence_thresh=-70)
        cut_files = []

        for i, (start_i, end_i) in enumerate(nonsilent_parts):
            chunk = song[start_i:end_i]

            # pydub lengths are in milliseconds; chunks shorter than 30 ms are discarded.
            if len(chunk) >= 30:
                chunk_file_path = os.path.join(os.path.dirname(file_path), f"chunk{i}.wav")
                if os.path.exists(chunk_file_path): os.remove(chunk_file_path)

                chunk.export(chunk_file_path, format="wav")
                cut_files.append(chunk_file_path)
            else: logger.warning(translations["skip_file"].format(i=i, chunk=len(chunk)))

        logger.info(f"{translations['split_total']}: {len(cut_files)}")

        def extract_number(filename):
            match = re.search(r'chunk(\d+)', filename)
            return int(match.group(1)) if match else 0

        cut_files = sorted(cut_files, key=extract_number)
        combined = AudioSegment.empty()

        for file in cut_files:
            combined += AudioSegment.from_file(file)

        output_path = os.path.splitext(file_path)[0] + "_processed" + ".wav"
        logger.info(translations["merge_audio"])
        combined.export(output_path, format="wav")

        return output_path
    except Exception as e:
        raise RuntimeError(f"{translations['process_audio_error']}: {e}")
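# Separate vocals from the mix with a Kim_Vocal MDX ONNX model (version 1 or 2),
# rename the exported "(Vocals)"/"(Instrumental)" stems, and return the vocal stem path.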
def separator_music_main(input, output, segments_size, overlap, denoise, version, hop_length, batch_size):
    if not os.path.exists(input):
        logger.warning(translations["input_not_valid"])
        return None

    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        return None

    model = f"Kim_Vocal_{version}.onnx"
    logger.info(translations["separator_process"].format(input=input))
    output_separator = separator_main(audio_file=input, model_filename=model, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)

    for f in output_separator:
        path = os.path.join(output, f)
        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))

        if '_(Instrumental)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
        elif '_(Vocals)_' in f:
            rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
            os.rename(path, rename_file)
            logger.info(f": {rename_file}")

    return rename_file
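# Remove reverb from a vocal stem with the Reverb_HQ_By_FoxJoy MDX ONNX model and
# return the renamed "(No Reverb)" stem path.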
def separator_reverb_audio(input, output, segments_size, overlap, denoise, hop_length, batch_size):
    reverb_models = "Reverb_HQ_By_FoxJoy.onnx"

    if not os.path.exists(input):
        logger.warning(translations["input_not_valid"])
        return None

    if not os.path.exists(output):
        logger.warning(translations["output_not_valid"])
        return None

    logger.info(f"{translations['dereverb']}: {input}...")
    output_dereverb = separator_main(audio_file=input, model_filename=reverb_models, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)

    for f in output_dereverb:
        path = os.path.join(output, f)
        if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))

        if '_(Reverb)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
        elif '_(No Reverb)_' in f:
            rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
            os.rename(path, rename_file)
            logger.info(f"{translations['dereverb_success']}: {rename_file}")

    return rename_file
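# Thin wrapper around the project's Separator (main.library.algorithm.separator):
# builds the instance with the requested MDX parameters, loads the ONNX model,
# and returns the list of separated stem files.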
def separator_main(audio_file=None, model_filename="Kim_Vocal_1.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True):
    separator = Separator(
        log_formatter=file_formatter,
        log_level=logging.INFO,
        output_dir=output_dir,
        output_format=output_format,
        output_bitrate=None,
        normalization_threshold=0.9,
        output_single_stem=None,
        invert_using_spec=False,
        sample_rate=44100,
        mdx_params={
            "hop_length": mdx_hop_length,
            "segment_size": mdx_segment_size,
            "overlap": mdx_overlap,
            "batch_size": mdx_batch_size,
            "enable_denoise": mdx_enable_denoise,
        },
    )

    separator.load_model(model_filename=model_filename)
    return separator.separate(audio_file)

if __name__ == "__main__": main()