# -*- coding: utf-8 -*-
"""
Faster Demucs separation for MacBook Pro (CPU by default; CUDA-friendly).

Key speedups (without breaking outputs):
- Use segmented inference (segment ~10-12s) + small overlap to reduce FFT cost and RAM.
- Clamp shifts on CPU (1) and keep small on GPU (2) for big speed gains
  (shifts is linear-time).
- Use torch.inference_mode() for faster no-grad path.
- Cache a single torchaudio Resample transform.
- Cap OMP/MKL threads to avoid oversubscription on laptop CPUs.
- Preallocate and accumulate 'instruments' in-place (avoids tensor churn).
- Keep outputs on CPU to reduce device swaps; on CUDA, run model+input in float16.

Public API preserved:
    init_demucs()
    load_model(...)
    release_model()
    separate_audio(folder, ...)
    extract_audio_from_video(folder)
    separate_all_audio_under_folder(root_folder, ...)
"""
import os
import subprocess
import time
import math
import gc
from typing import Tuple, Optional

# ---- Threading caps (greatly helps FFT-heavy CPU runs) ----
# These must be set BEFORE torch is imported, which is why the third-party
# imports below intentionally come after this block.
MAX_THREADS = max(1, min(8, os.cpu_count() or 4))
os.environ.setdefault("OMP_NUM_THREADS", str(MAX_THREADS))
os.environ.setdefault("MKL_NUM_THREADS", str(MAX_THREADS))

from loguru import logger
import torch
import torchaudio
import torchaudio.functional as AF
from torchaudio.transforms import Resample

# Demucs programmatic API
from demucs import pretrained
from demucs.apply import apply_model

from .utils import save_wav, normalize_wav  # noqa: F401

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

# -----------------------------
# Globals
# -----------------------------
auto_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_model: Optional[torch.nn.Module] = None
_model_loaded: bool = False
current_model_config = {}
_resampler: Optional[Resample] = None
_TARGET_SR = 44100


def _pick_device(device: str):
    """Resolve ``'auto'`` to the detected default device, else wrap *device*."""
    return auto_device if device == 'auto' else torch.device(device)


def _defaults_for_hardware():
    """Return default inference settings depending on CUDA availability."""
    if torch.cuda.is_available():
        return dict(shifts=2, segment=12.0, overlap=0.10, dtype=torch.float16)
    return dict(shifts=1, segment=10.0, overlap=0.10, dtype=torch.float32)


def _requested_config(model_name: str, device: str, shifts: Optional[int]) -> dict:
    """Build the normalized configuration dict used to decide on model reuse.

    ``shifts=None`` falls back to the hardware default, and any value above 1
    is clamped to 1 when the resolved device is not CUDA. Both ``load_model``
    and ``separate_audio`` use this helper so that they compare like with like.
    """
    target = _pick_device(device)
    eff_shifts = int(shifts) if shifts is not None else _defaults_for_hardware()["shifts"]
    if target.type != 'cuda' and eff_shifts > 1:
        # Cost is linear in shifts; keep non-CUDA runs at a single pass.
        eff_shifts = 1
    return {
        'model_name': model_name,
        'device': 'auto' if device == 'auto' else str(device),
        'shifts': eff_shifts,
    }


def init_demucs():
    """Load the default Demucs model once; later calls are no-ops."""
    global _model, _model_loaded
    if not _model_loaded:
        _model = load_model()
        _model_loaded = True
    else:
        logger.info("Demucs model already loaded — skipping initialization.")


def load_model(model_name: str = "htdemucs_ft", device: str = 'auto',
               progress: bool = True, shifts: int = 5):
    """Load (or reuse) a pretrained Demucs model and cache it module-wide.

    Args:
        model_name: Name passed to ``demucs.pretrained.get_model``.
        device: ``'auto'`` or anything accepted by ``torch.device``.
        progress: Accepted for API compatibility; not used while loading.
        shifts: Requested number of shifts; ``None`` uses the hardware
            default, and values above 1 are clamped on non-CUDA devices.

    Returns:
        The loaded model, in eval mode, on the resolved device.
    """
    global _model, _model_loaded, current_model_config

    requested_config = _requested_config(model_name, device, shifts)

    if _model is not None and current_model_config == requested_config:
        logger.info('Demucs model already loaded with the same configuration — reusing existing model.')
        return _model

    if _model is not None:
        logger.info('Demucs configuration changed — reloading model.')
        release_model()

    logger.info(f'Loading Demucs model: {model_name}')
    t_start = time.time()

    device_to_use = _pick_device(device)
    model = pretrained.get_model(model_name)
    model.eval()
    model.to(device_to_use)
    # Half precision only when the model actually lives on CUDA. Keying this on
    # mere CUDA availability would produce a half model fed with float32 input
    # when the caller explicitly requested the CPU.
    if device_to_use.type == 'cuda':
        model.half()

    current_model_config = requested_config
    _model = model
    _model_loaded = True
    logger.info(f'Demucs model loaded successfully in {time.time() - t_start:.2f}s.')
    return _model


def release_model():
    """Drop the cached model and free Python/CUDA memory held for it."""
    global _model, _model_loaded, current_model_config
    if _model is not None:
        logger.info('Releasing Demucs model resources...')
        _model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        _model_loaded = False
        current_model_config = {}
        logger.info('Demucs model resources released.')


def _get_resampler(orig_sr: int, new_sr: int):
    """Return a cached ``Resample`` transform, rebuilding it if the rates differ."""
    global _resampler
    if (_resampler is None
            or _resampler.orig_freq != orig_sr
            or _resampler.new_freq != new_sr):
        _resampler = Resample(orig_freq=orig_sr, new_freq=new_sr)
    return _resampler


def _load_audio_as_tensor(path: str, target_sr: int = _TARGET_SR) -> torch.Tensor:
    """Load *path*, resample to *target_sr*, and return a 2-channel float tensor.

    Inputs with more than two channels are truncated to the first two; mono
    input is duplicated to two channels.
    """
    wav, sr = torchaudio.load(path)
    if sr != target_sr:
        wav = _get_resampler(sr, target_sr)(wav)
    if wav.shape[0] > 2:
        wav = wav[:2, :]
    elif wav.shape[0] == 1:
        # NOTE(review): pretrained Demucs models are expected to take stereo
        # input; duplicate the mono channel rather than fail inside the model.
        wav = wav.repeat(2, 1)
    return wav.contiguous().float()


def _pick_aligned_segment_seconds(model, clip_num_samples: int,
                                  sr: int) -> Tuple[Optional[float], float]:
    """Pick a segment length (seconds) aligned to the model's training length.

    Returns ``(None, 0.10)`` when the model exposes no usable
    ``training_length``/``samplerate``; otherwise a multiple (1x-3x) of the
    training length together with an overlap value.

    NOTE(review): whether Demucs models expose ``training_length`` depends on
    the installed Demucs version and model class — confirm; when the attribute
    is absent this always takes the ``(None, 0.10)`` path.
    """
    training_len = getattr(model, "training_length", None)
    model_sr = getattr(model, "samplerate", sr)
    if not training_len or model_sr <= 0:
        return None, 0.10

    base_seg_s = training_len / float(model_sr)
    clip_seconds = clip_num_samples / float(sr)
    multiple = int(clip_seconds // base_seg_s)
    multiple = max(1, min(multiple, 3))
    segment_s = multiple * base_seg_s

    overlap = 0.10 if segment_s > 0.15 else 0.05
    if overlap >= segment_s:
        overlap = max(0.0, 0.5 * segment_s)
    return segment_s, overlap


def separate_audio(folder: str, model_name: str = "htdemucs_ft", device: str = 'auto',
                   progress: bool = True, shifts: int = 5):
    """Split ``<folder>/audio.wav`` into vocal and instrumental WAV files.

    Args:
        folder: Directory containing ``audio.wav``.
        model_name: Demucs model to use.
        device: ``'auto'`` or anything accepted by ``torch.device``.
        progress: Forwarded to ``apply_model``.
        shifts: Requested shifts (normalized the same way as in ``load_model``).

    Returns:
        ``(vocals_path, instruments_path)``, or ``(None, None)`` when
        ``audio.wav`` does not exist.

    Raises:
        Exception: Any failure during separation is logged, the cached model is
            released, and the exception is re-raised.
    """
    audio_path = os.path.join(folder, 'audio.wav')
    if not os.path.exists(audio_path):
        return None, None

    vocal_output_path = os.path.join(folder, 'audio_vocals.wav')
    instruments_output_path = os.path.join(folder, 'audio_instruments.wav')

    if os.path.exists(vocal_output_path) and os.path.exists(instruments_output_path):
        logger.info(f'Audio already separated: {folder}')
        return vocal_output_path, instruments_output_path

    logger.info(f'Separating audio: {folder}')

    try:
        # Compare against the same normalized config load_model stores, so a
        # cpu<->cuda switch is detected and clamped shifts do not look like a
        # configuration change on every call.
        wanted = _requested_config(model_name, device, shifts)
        if _model is None or not _model_loaded or current_model_config != wanted:
            load_model(model_name, device, progress, shifts)

        device_to_use = _pick_device(device)
        t_start = time.time()

        wav = _load_audio_as_tensor(audio_path, target_sr=_TARGET_SR)
        num_samples = wav.shape[-1]
        segment_s, overlap = _pick_aligned_segment_seconds(_model, num_samples, _TARGET_SR)

        wav_in = wav.unsqueeze(0).to(device_to_use, non_blocking=True)
        if device_to_use.type == 'cuda':
            # Matches the half-precision model loaded for CUDA.
            wav_in = wav_in.half()

        eff_shifts = current_model_config.get('shifts', 1)

        with torch.inference_mode():
            # NOTE(review): the ``segment`` keyword is assumed to be supported
            # by the installed demucs.apply.apply_model — confirm the version.
            sources_tensor = apply_model(
                _model,
                wav_in,
                shifts=eff_shifts,
                progress=progress,
                overlap=overlap,
                split=True,
                segment=segment_s,
            )[0]

        # Keep everything downstream on CPU in float32 for saving.
        sources_tensor = sources_tensor.to(dtype=torch.float32, device='cpu')
        name_to_src = {name: sources_tensor[i] for i, name in enumerate(_model.sources)}

        vocals = name_to_src.get('vocals')
        if vocals is None:
            logger.warning("This Demucs model does not include a 'vocals' stem — generating silent vocals.")
            vocals = torch.zeros_like(wav)

        # Sum every non-vocal stem in place into one accompaniment track.
        instruments = torch.zeros_like(vocals)
        for stem_name, stem in name_to_src.items():
            if stem_name != 'vocals':
                instruments.add_(stem)

        # Nothing accumulated (all-zero): fall back to mixture minus vocals.
        if not instruments.abs().sum().item():
            instruments = wav - vocals

        save_wav(vocals.transpose(0, 1).numpy(), vocal_output_path, sample_rate=_TARGET_SR)
        logger.info(f'Saved vocals: {vocal_output_path}')

        save_wav(instruments.transpose(0, 1).numpy(), instruments_output_path, sample_rate=_TARGET_SR)
        logger.info(f'Saved accompaniment: {instruments_output_path}')

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        logger.info(f'Audio separation complete in {time.time() - t_start:.2f}s.')
        return vocal_output_path, instruments_output_path

    except Exception as e:
        logger.error(f'Audio separation failed: {str(e)}')
        release_model()
        raise


def extract_audio_from_video(folder: str) -> bool:
    """Extract ``<folder>/download.mp4`` audio into ``<folder>/audio.wav`` via ffmpeg.

    Returns:
        ``True`` if ``audio.wav`` already existed or was extracted successfully;
        ``False`` if the video is missing, ffmpeg could not be started, or
        ffmpeg exited with an error.
    """
    video_path = os.path.join(folder, 'download.mp4')
    if not os.path.exists(video_path):
        return False

    audio_path = os.path.join(folder, 'audio.wav')
    if os.path.exists(audio_path):
        logger.info(f'Audio already extracted: {folder}')
        return True

    logger.info(f'Extracting audio from video: {folder}')
    # Argument list (no shell) so paths with quotes, spaces or shell
    # metacharacters are passed to ffmpeg verbatim.
    cmd = [
        'ffmpeg', '-loglevel', 'error',
        '-i', video_path,
        '-vn', '-acodec', 'pcm_s16le',
        '-ar', str(_TARGET_SR), '-ac', '2',
        audio_path,
    ]
    try:
        result = subprocess.run(cmd, check=False)
    except OSError as e:
        logger.error(f'Failed to run ffmpeg: {e}')
        return False

    if result.returncode != 0 or not os.path.exists(audio_path):
        logger.error(f'Audio extraction failed (ffmpeg exit code {result.returncode}): {folder}')
        # Remove a partial file so a later run does not treat it as done.
        if os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError:
                logger.warning(f'Could not remove incomplete audio file: {audio_path}')
        return False

    logger.info(f'Audio extraction complete: {folder}')
    return True


def separate_all_audio_under_folder(root_folder: str, model_name: str = "htdemucs_ft",
                                    device: str = 'auto', progress: bool = True,
                                    shifts: int = 5):
    """Walk *root_folder* and separate audio in every directory with a ``download.mp4``.

    Returns:
        ``(status_message, vocals_path, instruments_path)`` where the two paths
        belong to the last directory processed (``None`` if none was).

    Raises:
        Exception: Any failure is logged, the cached model is released, and the
            exception is re-raised.
    """
    vocal_output_path, instruments_output_path = None, None
    try:
        for subdir, dirs, files in os.walk(root_folder):
            files_set = set(files)
            if 'download.mp4' not in files_set:
                continue

            if 'audio.wav' not in files_set:
                if not extract_audio_from_video(subdir):
                    # Without audio.wav there is nothing to separate here.
                    logger.warning(f'Skipping folder, audio extraction failed: {subdir}')
                    continue
                files_set = set(os.listdir(subdir))

            if 'audio_vocals.wav' not in files_set or 'audio_instruments.wav' not in files_set:
                vocal_output_path, instruments_output_path = separate_audio(
                    subdir, model_name, device, progress, shifts
                )
            else:
                vocal_output_path = os.path.join(subdir, 'audio_vocals.wav')
                instruments_output_path = os.path.join(subdir, 'audio_instruments.wav')
                logger.info(f'Audio already separated: {subdir}')

        logger.info(f'All audio separation completed under: {root_folder}')
        return f'All audio separated: {root_folder}', vocal_output_path, instruments_output_path

    except Exception as e:
        logger.error(f'Error during audio separation: {str(e)}')
        release_model()
        raise


if __name__ == '__main__':
    folder = r"videos"
    separate_all_audio_under_folder(folder, shifts=1)