# coding: utf-8 __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' import argparse import numpy as np import torch import torch.nn as nn import yaml import os import soundfile as sf import matplotlib.pyplot as plt from ml_collections import ConfigDict from omegaconf import OmegaConf from tqdm.auto import tqdm from typing import Dict, List, Tuple, Any, Union import loralib as lora def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]: """ Load the configuration from the specified path based on the model type. Parameters: ---------- model_type : str The type of model to load (e.g., 'htdemucs', 'mdx23c', etc.). config_path : str The path to the YAML or OmegaConf configuration file. Returns: ------- config : Any The loaded configuration, which can be in different formats (e.g., OmegaConf or ConfigDict). Raises: ------ FileNotFoundError: If the configuration file at `config_path` is not found. ValueError: If there is an error loading the configuration file. """ try: with open(config_path, 'r') as f: if model_type == 'htdemucs': config = OmegaConf.load(config_path) else: config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) return config except FileNotFoundError: raise FileNotFoundError(f"Configuration file not found at {config_path}") except Exception as e: raise ValueError(f"Error loading configuration: {e}") def get_model_from_config(model_type: str, config_path: str) -> Tuple: """ Load the model specified by the model type and configuration file. Parameters: ---------- model_type : str The type of model to load (e.g., 'mdx23c', 'htdemucs', 'scnet', etc.). config_path : str The path to the configuration file (YAML or OmegaConf format). Returns: ------- model : nn.Module or None The initialized model based on the `model_type`, or None if the model type is not recognized. config : Any The configuration used to initialize the model. This could be in different formats depending on the model type (e.g., OmegaConf, ConfigDict). Raises: ------ ValueError: If the `model_type` is unknown or an error occurs during model initialization. """ config = load_config(model_type, config_path) if model_type == 'mdx23c': from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net model = TFC_TDF_net(config) elif model_type == 'htdemucs': from models.demucs4ht import get_model model = get_model(config) elif model_type == 'segm_models': from models.segm_models import Segm_Models_Net model = Segm_Models_Net(config) elif model_type == 'torchseg': from models.torchseg_models import Torchseg_Net model = Torchseg_Net(config) elif model_type == 'mel_band_roformer': from models.bs_roformer import MelBandRoformer model = MelBandRoformer(**dict(config.model)) elif model_type == 'bs_roformer': from models.bs_roformer import BSRoformer model = BSRoformer(**dict(config.model)) elif model_type == 'swin_upernet': from models.upernet_swin_transformers import Swin_UperNet_Model model = Swin_UperNet_Model(config) elif model_type == 'bandit': from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model) elif model_type == 'bandit_v2': from models.bandit_v2.bandit import Bandit model = Bandit(**config.kwargs) elif model_type == 'scnet_unofficial': from models.scnet_unofficial import SCNet model = SCNet(**config.model) elif model_type == 'scnet': from models.scnet import SCNet model = SCNet(**config.model) elif model_type == 'apollo': from models.look2hear.models import BaseModel model = BaseModel.apollo(**config.model) elif model_type == 'bs_mamba2': from models.ts_bs_mamba2 import Separator model = Separator(**config.model) elif model_type == 'experimental_mdx23c_stht': from models.mdx23c_tfc_tdf_v3_with_STHT import TFC_TDF_net model = TFC_TDF_net(config) else: raise ValueError(f"Unknown model type: {model_type}") return model, config def read_audio_transposed(path: str, instr: str = None, skip_err: bool = False) -> Tuple[np.ndarray, int]: """ Reads an audio file, ensuring mono audio is converted to two-dimensional format, and transposes the data to have channels as the first dimension. Parameters ---------- path : str Path to the audio file. skip_err: bool If true, not raise errors instr: name of instument Returns ------- Tuple[np.ndarray, int] A tuple containing: - Transposed audio data as a NumPy array with shape (channels, length). For mono audio, the shape will be (1, length). - Sampling rate (int), e.g., 44100. """ try: mix, sr = sf.read(path) except Exception as e: if skip_err: print(f"No stem {instr}: skip!") return None, None else: raise RuntimeError(f"Error reading the file at {path}: {e}") else: if len(mix.shape) == 1: # For mono audio mix = np.expand_dims(mix, axis=-1) return mix.T, sr def normalize_audio(audio: np.ndarray) -> tuple[np.ndarray, Dict[str, float]]: """ Normalize an audio signal by subtracting the mean and dividing by the standard deviation. Parameters: ---------- audio : np.ndarray Input audio array with shape (channels, time) or (time,). Returns: ------- tuple[np.ndarray, dict[str, float]] - Normalized audio array with the same shape as the input. - Dictionary containing the mean and standard deviation of the original audio. """ mono = audio.mean(0) mean, std = mono.mean(), mono.std() return (audio - mean) / std, {"mean": mean, "std": std} def denormalize_audio(audio: np.ndarray, norm_params: Dict[str, float]) -> np.ndarray: """ Denormalize an audio signal by reversing the normalization process (multiplying by the standard deviation and adding the mean). Parameters: ---------- audio : np.ndarray Normalized audio array to be denormalized. norm_params : dict[str, float] Dictionary containing the 'mean' and 'std' values used for normalization. Returns: ------- np.ndarray Denormalized audio array with the same shape as the input. """ return audio * norm_params["std"] + norm_params["mean"] def apply_tta( config, model: torch.nn.Module, mix: torch.Tensor, waveforms_orig: Dict[str, torch.Tensor], device: torch.device, model_type: str ) -> Dict[str, torch.Tensor]: """ Apply Test-Time Augmentation (TTA) for source separation. This function processes the input mixture with test-time augmentations, including channel inversion and polarity inversion, to enhance the separation results. The results from all augmentations are averaged to produce the final output. Parameters: ---------- config : Any Configuration object containing model and processing parameters. model : torch.nn.Module The trained model used for source separation. mix : torch.Tensor The mixed audio tensor with shape (channels, time). waveforms_orig : Dict[str, torch.Tensor] Dictionary of original separated waveforms (before TTA) for each instrument. device : torch.device Device (CPU or CUDA) on which the model will be executed. model_type : str Type of the model being used (e.g., "demucs", "custom_model"). Returns: ------- Dict[str, torch.Tensor] Updated dictionary of separated waveforms after applying TTA. """ # Create augmentations: channel inversion and polarity inversion track_proc_list = [mix[::-1].copy(), -1.0 * mix.copy()] # Process each augmented mixture for i, augmented_mix in enumerate(track_proc_list): waveforms = demix(config, model, augmented_mix, device, model_type=model_type) for el in waveforms: if i == 0: waveforms_orig[el] += waveforms[el][::-1].copy() else: waveforms_orig[el] -= waveforms[el] # Average the results across augmentations for el in waveforms_orig: waveforms_orig[el] /= len(track_proc_list) + 1 return waveforms_orig def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor: """ Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end. This function creates a window of size `window_size` where the first `fade_size` elements linearly increase from 0 to 1 (fade-in) and the last `fade_size` elements linearly decrease from 1 to 0 (fade-out). The middle part of the window is filled with ones. Parameters: ---------- window_size : int The total size of the window. fade_size : int The size of the fade-in and fade-out regions. Returns: ------- torch.Tensor A tensor of shape (window_size,) containing the generated windowing array. Example: ------- If `window_size=10` and `fade_size=3`, the output will be: tensor([0.0000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.0000]) """ fadein = torch.linspace(0, 1, fade_size) fadeout = torch.linspace(1, 0, fade_size) window = torch.ones(window_size) window[-fade_size:] = fadeout window[:fade_size] = fadein return window def demix( config: ConfigDict, model: torch.nn.Module, mix: torch.Tensor, device: torch.device, model_type: str, pbar: bool = False ) -> Tuple[List[Dict[str, np.ndarray]], np.ndarray]: """ Unified function for audio source separation with support for multiple processing modes. This function separates audio into its constituent sources using either a generic custom logic or a Demucs-specific logic. It supports batch processing and overlapping window-based chunking for efficient and artifact-free separation. Parameters: ---------- config : ConfigDict Configuration object containing audio and inference settings. model : torch.nn.Module The trained model used for audio source separation. mix : torch.Tensor Input audio tensor with shape (channels, time). device : torch.device The computation device (CPU or CUDA). model_type : str, optional Processing mode: - "demucs" for logic specific to the Demucs model. Default is "generic". pbar : bool, optional If True, displays a progress bar during chunk processing. Default is False. Returns: ------- Union[Dict[str, np.ndarray], np.ndarray] - A dictionary mapping target instruments to separated audio sources if multiple instruments are present. - A numpy array of the separated source if only one instrument is present. """ mix = torch.tensor(mix, dtype=torch.float32) if model_type == 'htdemucs': mode = 'demucs' else: mode = 'generic' # Define processing parameters based on the mode if mode == 'demucs': chunk_size = config.training.samplerate * config.training.segment num_instruments = len(config.training.instruments) num_overlap = config.inference.num_overlap step = chunk_size // num_overlap else: chunk_size = config.audio.chunk_size num_instruments = len(prefer_target_instrument(config)) num_overlap = config.inference.num_overlap fade_size = chunk_size // 10 step = chunk_size // num_overlap border = chunk_size - step length_init = mix.shape[-1] windowing_array = _getWindowingArray(chunk_size, fade_size) # Add padding for generic mode to handle edge artifacts if length_init > 2 * border and border > 0: mix = nn.functional.pad(mix, (border, border), mode="reflect") batch_size = config.inference.batch_size use_amp = getattr(config.training, 'use_amp', True) with torch.cuda.amp.autocast(enabled=use_amp): with torch.inference_mode(): # Initialize result and counter tensors req_shape = (num_instruments,) + mix.shape result = torch.zeros(req_shape, dtype=torch.float32) counter = torch.zeros(req_shape, dtype=torch.float32) i = 0 batch_data = [] batch_locations = [] progress_bar = tqdm( total=mix.shape[1], desc="Processing audio chunks", leave=False ) if pbar else None while i < mix.shape[1]: # Extract chunk and apply padding if necessary part = mix[:, i:i + chunk_size].to(device) chunk_len = part.shape[-1] if mode == "generic" and chunk_len > chunk_size // 2: pad_mode = "reflect" else: pad_mode = "constant" part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0) batch_data.append(part) batch_locations.append((i, chunk_len)) i += step # Process batch if it's full or the end is reached if len(batch_data) >= batch_size or i >= mix.shape[1]: arr = torch.stack(batch_data, dim=0) x = model(arr) if mode == "generic": window = windowing_array.clone() # using clone() fixes the clicks at chunk edges when using batch_size=1 if i - step == 0: # First audio chunk, no fadein window[:fade_size] = 1 elif i >= mix.shape[1]: # Last audio chunk, no fadeout window[-fade_size:] = 1 for j, (start, seg_len) in enumerate(batch_locations): if mode == "generic": result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() * window[..., :seg_len] counter[..., start:start + seg_len] += window[..., :seg_len] else: result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() counter[..., start:start + seg_len] += 1.0 batch_data.clear() batch_locations.clear() if progress_bar: progress_bar.update(step) if progress_bar: progress_bar.close() # Compute final estimated sources estimated_sources = result / counter estimated_sources = estimated_sources.cpu().numpy() np.nan_to_num(estimated_sources, copy=False, nan=0.0) # Remove padding for generic mode if mode == "generic": if length_init > 2 * border and border > 0: estimated_sources = estimated_sources[..., border:-border] # Return the result as a dictionary or a single array if mode == "demucs": instruments = config.training.instruments else: instruments = prefer_target_instrument(config) ret_data = {k: v for k, v in zip(instruments, estimated_sources)} if mode == "demucs" and num_instruments <= 1: return estimated_sources else: return ret_data def prefer_target_instrument(config: ConfigDict) -> List[str]: """ Return the list of target instruments based on the configuration. If a specific target instrument is specified in the configuration, it returns a list with that instrument. Otherwise, it returns the list of instruments. Parameters: ---------- config : ConfigDict Configuration object containing the list of instruments or the target instrument. Returns: ------- List[str] A list of target instruments. """ if getattr(config.training, 'target_instrument', None): return [config.training.target_instrument] else: return config.training.instruments def load_not_compatible_weights(model: torch.nn.Module, weights: str, verbose: bool = False) -> None: """ Load weights into a model, handling mismatched shapes and dimensions. Args: model: PyTorch model into which the weights will be loaded. weights: Path to the weights file. verbose: If True, prints detailed information about matching and mismatched layers. """ new_model = model.state_dict() old_model = torch.load(weights) if 'state' in old_model: # Fix for htdemucs weights loading old_model = old_model['state'] if 'state_dict' in old_model: # Fix for apollo weights loading old_model = old_model['state_dict'] for el in new_model: if el in old_model: if verbose: print(f'Match found for {el}!') if new_model[el].shape == old_model[el].shape: if verbose: print('Action: Just copy weights!') new_model[el] = old_model[el] else: if len(new_model[el].shape) != len(old_model[el].shape): if verbose: print('Action: Different dimension! Too lazy to write the code... Skip it') else: if verbose: print(f'Shape is different: {tuple(new_model[el].shape)} != {tuple(old_model[el].shape)}') ln = len(new_model[el].shape) max_shape = [] slices_old = [] slices_new = [] for i in range(ln): max_shape.append(max(new_model[el].shape[i], old_model[el].shape[i])) slices_old.append(slice(0, old_model[el].shape[i])) slices_new.append(slice(0, new_model[el].shape[i])) # print(max_shape) # print(slices_old, slices_new) slices_old = tuple(slices_old) slices_new = tuple(slices_new) max_matrix = np.zeros(max_shape, dtype=np.float32) for i in range(ln): max_matrix[slices_old] = old_model[el].cpu().numpy() max_matrix = torch.from_numpy(max_matrix) new_model[el] = max_matrix[slices_new] else: if verbose: print(f'Match not found for {el}!') model.load_state_dict( new_model ) def load_lora_weights(model: torch.nn.Module, lora_path: str, device: str = 'cpu') -> None: """ Load LoRA weights into a model. This function updates the given model with LoRA-specific weights from the specified checkpoint file. It does not require the checkpoint to match the model's full state dictionary, as only LoRA layers are updated. Parameters: ---------- model : Module The PyTorch model into which the LoRA weights will be loaded. lora_path : str Path to the LoRA checkpoint file. device : str, optional The device to load the weights onto, by default 'cpu'. Common values are 'cpu' or 'cuda'. Returns: ------- None The model is updated in place. """ lora_state_dict = torch.load(lora_path, map_location=device) model.load_state_dict(lora_state_dict, strict=False) def load_start_checkpoint(args: argparse.Namespace, model: torch.nn.Module, type_='train') -> None: """ Load the starting checkpoint for a model. Args: args: Parsed command-line arguments containing the checkpoint path. model: PyTorch model to load the checkpoint into. type_: how to load weights - for train we can load not fully compatible weights """ print(f'Start from checkpoint: {args.start_check_point}') if type_ in ['train']: if 1: load_not_compatible_weights(model, args.start_check_point, verbose=False) else: model.load_state_dict(torch.load(args.start_check_point)) else: device='cpu' if args.model_type in ['htdemucs', 'apollo']: state_dict = torch.load(args.start_check_point, map_location=device, weights_only=False) # Fix for htdemucs pretrained models if 'state' in state_dict: state_dict = state_dict['state'] # Fix for apollo pretrained models if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] else: state_dict = torch.load(args.start_check_point, map_location=device, weights_only=True) model.load_state_dict(state_dict) if args.lora_checkpoint: print(f"Loading LoRA weights from: {args.lora_checkpoint}") load_lora_weights(model, args.lora_checkpoint) def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module: """ Replaces specific layers in the model with LoRA-extended versions. Parameters: ---------- config : Dict[str, Any] Configuration containing parameters for LoRA. It should include a 'lora' key with parameters for `MergedLinear`. model : nn.Module The original model in which the layers will be replaced. Returns: ------- nn.Module The modified model with the replaced layers. """ if 'lora' not in config: raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.") replaced_layers = 0 # Counter for replaced layers for name, module in model.named_modules(): hierarchy = name.split('.') layer_name = hierarchy[-1] # Check if this is the target layer to replace (and layer_name == 'to_qkv') if isinstance(module, nn.Linear): try: # Get the parent module parent_module = model for submodule_name in hierarchy[:-1]: parent_module = getattr(parent_module, submodule_name) # Replace the module with LoRA-enabled layer setattr( parent_module, layer_name, lora.MergedLinear( in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, **config['lora'] ) ) replaced_layers += 1 # Increment the counter except Exception as e: print(f"Error replacing layer {name}: {e}") if replaced_layers == 0: print("Warning: No layers were replaced. Check the model structure and configuration.") else: print(f"Number of layers replaced with LoRA: {replaced_layers}") return model def draw_spectrogram(waveform, sample_rate, length, output_file): import librosa.display # Cut only required part of spectorgram x = waveform[:int(length * sample_rate), :] X = librosa.stft(x.mean(axis=-1)) # perform short-term fourier transform on mono signal Xdb = librosa.amplitude_to_db(np.abs(X), ref=np.max) # convert an amplitude spectrogram to dB-scaled spectrogram. fig, ax = plt.subplots() # plt.figure(figsize=(30, 10)) # initialize the fig size img = librosa.display.specshow( Xdb, cmap='plasma', sr=sample_rate, x_axis='time', y_axis='linear', ax=ax ) ax.set(title='File: ' + os.path.basename(output_file)) fig.colorbar(img, ax=ax, format="%+2.f dB") if output_file is not None: plt.savefig(output_file)