import os
import datetime
import json
import logging
import pickle
from typing import Dict

import librosa
import numpy as np
import torch
import torch.nn as nn
import yaml

from models.audiosep import AudioSep, get_model_class


def ignore_warnings():
    import warnings

    # Ignore UserWarning from torch.meshgrid
    warnings.filterwarnings('ignore', category=UserWarning, module='torch.functional')

    # Regex pattern capturing variations of the RoBERTa checkpoint warning message
    pattern = r"Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: \['lm_head\..*'\].*"
    warnings.filterwarnings('ignore', message=pattern)


def create_logging(log_dir, filemode):
    os.makedirs(log_dir, exist_ok=True)

    # Find the next unused log file index in log_dir
    i1 = 0
    while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))):
        i1 += 1
    log_path = os.path.join(log_dir, "{:04d}.log".format(i1))

    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    # Print to console
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
    console.setFormatter(formatter)
    logging.getLogger("").addHandler(console)

    return logging


def float32_to_int16(x: np.ndarray) -> np.ndarray:
    x = np.clip(x, a_min=-1, a_max=1)
    return (x * 32767.0).astype(np.int16)


def int16_to_float32(x: np.ndarray) -> np.ndarray:
    return (x / 32767.0).astype(np.float32)
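

# Illustrative sketch (added example, not part of the original utilities): round-trip a
# float32 waveform through int16 quantization. The reconstruction error is bounded by
# roughly one quantization step (1 / 32767).
def _example_int16_roundtrip():
    x = np.sin(2 * np.pi * 440 * np.arange(32000) / 32000).astype(np.float32)
    x_int16 = float32_to_int16(x)
    x_rec = int16_to_float32(x_int16)
    assert np.max(np.abs(x - x_rec)) < 1.0 / 32767.0 + 1e-6
    return x_rec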


def parse_yaml(config_yaml: str) -> Dict:
    r"""Parse a yaml config file.

    Args:
        config_yaml (str): config yaml path

    Returns:
        yaml_dict (Dict): parsed yaml file
    """
    with open(config_yaml, "r") as fr:
        return yaml.load(fr, Loader=yaml.FullLoader)


def get_audioset632_id_to_lb(ontology_path: str) -> Dict:
    r"""Get AudioSet 632 classes ID to label mapping."""
    audioset632_id_to_lb = {}

    with open(ontology_path) as f:
        data_list = json.load(f)

    for e in data_list:
        audioset632_id_to_lb[e["id"]] = e["name"]

    return audioset632_id_to_lb


def load_pretrained_panns(
    model_type: str,
    checkpoint_path: str,
    freeze: bool
) -> nn.Module:
    r"""Load pretrained audio neural networks (PANNs).

    Args:
        model_type (str): e.g., "Cnn14"
        checkpoint_path (str): e.g., "Cnn14_mAP=0.431.pth"
        freeze (bool): freeze model parameters if True

    Returns:
        model: nn.Module
    """
    # NOTE: Cnn14 and Cnn14_DecisionLevelMax are assumed to be importable from the
    # PANNs model definitions; they are not defined or imported in this module.
    if model_type == "Cnn14":
        Model = Cnn14
    elif model_type == "Cnn14_DecisionLevelMax":
        Model = Cnn14_DecisionLevelMax
    else:
        raise NotImplementedError

    model = Model(sample_rate=32000, window_size=1024, hop_size=320,
                  mel_bins=64, fmin=50, fmax=14000, classes_num=527)

    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    return model
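

# Illustrative sketch: load a frozen Cnn14 tagger for evaluation. The checkpoint path
# matches the example in the docstring above but is assumed to exist locally, and the
# Cnn14 class must be importable as noted in load_pretrained_panns.
def _example_load_panns():
    panns = load_pretrained_panns(
        model_type="Cnn14",
        checkpoint_path="Cnn14_mAP=0.431.pth",  # assumed local path
        freeze=True,
    )
    return panns.eval()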


def energy(x):
    return torch.mean(x ** 2)


def magnitude_to_db(x):
    eps = 1e-10
    return 20. * np.log10(max(x, eps))


def db_to_magnitude(x):
    return 10. ** (x / 20)


def ids_to_hots(ids, classes_num, device):
    # Convert a list of class IDs to a multi-hot vector of length classes_num
    hots = torch.zeros(classes_num).to(device)
    for class_id in ids:
        hots[class_id] = 1
    return hots


def calculate_sdr(
    ref: np.ndarray,
    est: np.ndarray,
    eps=1e-10
) -> float:
    r"""Calculate SDR between reference and estimation.

    Args:
        ref (np.ndarray): reference signal
        est (np.ndarray): estimated signal

    Returns:
        sdr (float): signal-to-distortion ratio in dB
    """
    reference = ref
    noise = est - reference

    numerator = np.clip(a=np.mean(reference ** 2), a_min=eps, a_max=None)
    denominator = np.clip(a=np.mean(noise ** 2), a_min=eps, a_max=None)
    sdr = 10. * np.log10(numerator / denominator)

    return sdr
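

# Worked example (illustrative): if the estimate is the reference scaled by 1.1, the
# error signal is 0.1 * ref, so SDR = 10 * log10(mean(ref**2) / mean((0.1 * ref)**2))
# = 10 * log10(100) = 20 dB.
def _example_sdr():
    ref = np.sin(2 * np.pi * 440 * np.arange(32000) / 32000)
    est = 1.1 * ref
    return calculate_sdr(ref=ref, est=est)  # ~20.0 dB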


def calculate_sisdr(ref, est):
    r"""Calculate scale-invariant SDR (SI-SDR) between reference and estimation.

    Args:
        ref (np.ndarray): reference signal
        est (np.ndarray): estimated signal
    """
    eps = np.finfo(ref.dtype).eps

    reference = ref.copy()
    estimate = est.copy()
    reference = reference.reshape(reference.size, 1)
    estimate = estimate.reshape(estimate.size, 1)

    Rss = np.dot(reference.T, reference)

    # Get the scaling factor for the clean (reference) source
    a = (eps + np.dot(reference.T, estimate)) / (Rss + eps)

    e_true = a * reference
    e_res = estimate - e_true

    Sss = (e_true ** 2).sum()
    Snn = (e_res ** 2).sum()

    sisdr = 10 * np.log10((eps + Sss) / (eps + Snn))
    return sisdr
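

# Illustrative contrast with calculate_sdr: SI-SDR first rescales the reference to best
# match the estimate, so a pure gain mismatch that costs ~20 dB of SDR leaves SI-SDR
# essentially unaffected (the residual after rescaling is numerically near zero).
def _example_sisdr_scale_invariance():
    ref = np.sin(2 * np.pi * 440 * np.arange(32000) / 32000)
    est = 1.1 * ref
    sdr = calculate_sdr(ref=ref, est=est)      # ~20 dB, penalized by the gain error
    sisdr = calculate_sisdr(ref=ref, est=est)  # very large, gain error is factored out
    return sdr, sisdr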


class StatisticsContainer(object):
    def __init__(self, statistics_path):
        self.statistics_path = statistics_path

        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        self.statistics_dict = {"balanced_train": [], "test": []}

    def append(self, steps, statistics, split, flush=True):
        statistics["steps"] = steps
        self.statistics_dict[split].append(statistics)

        if flush:
            self.flush()

    def flush(self):
        with open(self.statistics_path, "wb") as f:
            pickle.dump(self.statistics_dict, f)

        with open(self.backup_statistics_path, "wb") as f:
            pickle.dump(self.statistics_dict, f)

        logging.info("    Dump statistics to {}".format(self.statistics_path))
        logging.info("    Dump statistics to {}".format(self.backup_statistics_path))
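

# Illustrative usage sketch for StatisticsContainer. The path below is hypothetical;
# its parent directory must already exist, since flush() opens the file for writing.
def _example_statistics_container():
    container = StatisticsContainer("workspace/statistics.pkl")  # hypothetical path
    container.append(steps=10000, statistics={"sdr": 5.3}, split="test")
    return container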


def get_mean_sdr_from_dict(sdris_dict):
    mean_sdr = np.nanmean(list(sdris_dict.values()))
    return mean_sdr


def remove_silence(audio: np.ndarray, sample_rate: int) -> np.ndarray:
    r"""Remove silent frames."""
    window_size = int(sample_rate * 0.1)
    threshold = 0.02

    frames = librosa.util.frame(x=audio, frame_length=window_size, hop_length=window_size).T
    # shape: (frames_num, window_size)

    new_frames = get_active_frames(frames, threshold)
    # shape: (new_frames_num, window_size)

    new_audio = new_frames.flatten()
    # shape: (new_audio_samples,)

    return new_audio


def get_active_frames(frames: np.ndarray, threshold: float) -> np.ndarray:
    r"""Get active frames."""
    energy = np.max(np.abs(frames), axis=-1)
    # shape: (frames_num,)

    active_indexes = np.where(energy > threshold)[0]
    # shape: (new_frames_num,)

    new_frames = frames[active_indexes]
    # shape: (new_frames_num, window_size)

    return new_frames


def repeat_to_length(audio: np.ndarray, segment_samples: int) -> np.ndarray:
    r"""Repeat audio to length."""
    repeats_num = (segment_samples // audio.shape[-1]) + 1
    audio = np.tile(audio, repeats_num)[0 : segment_samples]

    return audio


def calculate_segmentwise_sdr(ref, est, hop_samples, return_sdr_list=False):
    r"""Calculate the median segment-wise SDR.

    Both ref and est are expected to have shape (channels, samples); SDR is computed
    on consecutive windows of hop_samples samples and aggregated with the median.
    """
    min_len = min(ref.shape[-1], est.shape[-1])

    pointer = 0
    sdrs = []

    while pointer + hop_samples < min_len:
        sdr = calculate_sdr(
            ref=ref[:, pointer : pointer + hop_samples],
            est=est[:, pointer : pointer + hop_samples],
        )
        sdrs.append(sdr)
        pointer += hop_samples

    sdr = np.nanmedian(sdrs)

    if return_sdr_list:
        return sdr, sdrs
    else:
        return sdr
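

# Illustrative example: segment-wise SDR over 1-second hops of a toy stereo pair shaped
# (channels, samples), the layout this function indexes. The sample rate and noise level
# are arbitrary choices for the sketch.
def _example_segmentwise_sdr():
    sample_rate = 32000
    t = np.arange(10 * sample_rate) / sample_rate
    ref = np.stack([np.sin(2 * np.pi * 440 * t), np.sin(2 * np.pi * 220 * t)])
    est = ref + 0.01 * np.random.randn(*ref.shape)
    median_sdr, sdr_list = calculate_segmentwise_sdr(
        ref=ref, est=est, hop_samples=sample_rate, return_sdr_list=True
    )
    return median_sdr, sdr_list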


def loudness(data, input_loudness, target_loudness):
    """Loudness normalize a signal.

    Normalize an input signal to a target loudness in dB LUFS.

    Parameters
    ----------
    data : torch.Tensor
        Input multichannel audio data.
    input_loudness : float
        Loudness of the input in dB LUFS.
    target_loudness : float
        Target loudness of the output in dB LUFS.

    Returns
    -------
    output : torch.Tensor
        Loudness normalized output data.
    """
    # Calculate the gain needed to scale to the desired loudness level
    delta_loudness = target_loudness - input_loudness
    gain = torch.pow(10.0, delta_loudness / 20.0)

    output = gain * data

    # Check for potentially clipped samples
    # if torch.max(torch.abs(output)) >= 1.0:
    #     warnings.warn("Possible clipped samples in output.")

    return output
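

# Worked example (illustrative): normalizing from -30 dB LUFS to -23 dB LUFS applies a
# gain of 10 ** (7 / 20), roughly 2.24, to every sample. The toy tensor below stands in
# for a measured signal.
def _example_loudness_normalize():
    data = 0.01 * torch.randn(2, 32000)  # toy stereo tensor
    output = loudness(data, input_loudness=-30.0, target_loudness=-23.0)
    return output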


def load_ss_model(
    configs: Dict,
    checkpoint_path: str,
    query_encoder: nn.Module
) -> nn.Module:
    r"""Load a trained universal source separation model.

    Args:
        configs (Dict): parsed model configuration
        checkpoint_path (str): path of the checkpoint to load
        query_encoder (nn.Module): query encoder used to condition the separator

    Returns:
        pl_model: pl.LightningModule
    """
    ss_model_type = configs["model"]["model_type"]
    input_channels = configs["model"]["input_channels"]
    output_channels = configs["model"]["output_channels"]
    condition_size = configs["model"]["condition_size"]

    # Initialize separation model
    SsModel = get_model_class(model_type=ss_model_type)
    ss_model = SsModel(
        input_channels=input_channels,
        output_channels=output_channels,
        condition_size=condition_size,
    )

    # Load PyTorch Lightning model
    pl_model = AudioSep.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        strict=False,
        ss_model=ss_model,
        waveform_mixer=None,
        query_encoder=query_encoder,
        loss_function=None,
        optimizer_type=None,
        learning_rate=None,
        lr_lambda_func=None,
        map_location=torch.device('cpu'),
    )

    return pl_model
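

# Illustrative sketch of loading a separation checkpoint for inference. The config and
# checkpoint paths below are hypothetical placeholders, and the query encoder is passed
# in by the caller (whatever encoder the checkpoint was trained with).
def _example_load_ss_model(query_encoder: nn.Module):
    configs = parse_yaml("config/audiosep_base.yaml")        # hypothetical path
    pl_model = load_ss_model(
        configs=configs,
        checkpoint_path="checkpoint/audiosep_base.ckpt",     # hypothetical path
        query_encoder=query_encoder,
    )
    return pl_model.eval()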