import librosa
import numpy as np
import torch
import torchaudio
import torchcrepe
from torchcrepe.loudness import REF_DB

SILENCE_THRESHOLD = -60
UNVOICED_THRESHOLD = 0.21

"""
Periodicity metrics adapted from https://github.com/descriptinc/cargan
"""


def predict_pitch(
    audio: torch.Tensor,
    silence_threshold: float = SILENCE_THRESHOLD,
    unvoiced_threshold: float = UNVOICED_THRESHOLD,
):
    """
    Predicts pitch and periodicity for the given audio.

    Args:
        audio (Tensor): The audio waveform.
        silence_threshold (float): The threshold for silence detection.
        unvoiced_threshold (float): The threshold for unvoiced detection.

    Returns:
        pitch (ndarray): The predicted pitch.
        periodicity (ndarray): The predicted periodicity.
    """
    # torchcrepe inference
    pitch, periodicity = torchcrepe.predict(
        audio,
        fmin=50.0,
        fmax=550.0,
        sample_rate=torchcrepe.SAMPLE_RATE,
        model="full",
        return_periodicity=True,
        device=audio.device,
        pad=False,
    )
    pitch = pitch.cpu().numpy()
    periodicity = periodicity.cpu().numpy()

    # Calculate dB-scaled spectrogram and set low-energy frames to unvoiced
    hop_length = torchcrepe.SAMPLE_RATE // 100  # default CREPE hop (10 ms)
    stft = torchaudio.functional.spectrogram(
        audio,
        window=torch.hann_window(torchcrepe.WINDOW_SIZE, device=audio.device),
        n_fft=torchcrepe.WINDOW_SIZE,
        hop_length=hop_length,
        win_length=torchcrepe.WINDOW_SIZE,
        power=2,
        normalized=False,
        pad=0,
        center=False,
    )

    # Perceptual (A-)weighting of the power spectrogram, referenced to REF_DB
    freqs = librosa.fft_frequencies(sr=torchcrepe.SAMPLE_RATE, n_fft=torchcrepe.WINDOW_SIZE)
    perceptual_stft = librosa.perceptual_weighting(stft.cpu().numpy(), freqs) - REF_DB

    # Silent frames get zero periodicity; low-periodicity frames are marked unvoiced
    silence = perceptual_stft.mean(axis=1) < silence_threshold
    periodicity[silence] = 0
    pitch[periodicity < unvoiced_threshold] = torchcrepe.UNVOICED

    return pitch, periodicity


def calculate_periodicity_metrics(y: torch.Tensor, y_hat: torch.Tensor):
    """
    Calculates periodicity metrics for the predicted and true audio data.

    Args:
        y (Tensor): The true audio data.
        y_hat (Tensor): The predicted audio data.

    Returns:
        periodicity_loss (float): The periodicity RMSE.
        pitch_loss (float): The pitch RMSE in cents, over frames voiced in both signals.
        f1 (float): The F1 score for voiced/unvoiced classification.
    """
    true_pitch, true_periodicity = predict_pitch(y)
    pred_pitch, pred_periodicity = predict_pitch(y_hat)

    # Unvoiced frames are marked NaN by predict_pitch
    true_voiced = ~np.isnan(true_pitch)
    pred_voiced = ~np.isnan(pred_pitch)

    # Periodicity RMSE per utterance, averaged over the batch
    periodicity_loss = np.sqrt(((pred_periodicity - true_periodicity) ** 2).mean(axis=1)).mean()

    # Pitch RMSE in cents, restricted to frames voiced in both signals
    voiced = true_voiced & pred_voiced
    difference_cents = 1200 * (np.log2(true_pitch[voiced]) - np.log2(pred_pitch[voiced]))
    pitch_loss = np.sqrt((difference_cents ** 2).mean())

    # Voiced/unvoiced precision, recall, and F1
    true_positives = (true_voiced & pred_voiced).sum()
    false_positives = (~true_voiced & pred_voiced).sum()
    false_negatives = (true_voiced & ~pred_voiced).sum()
    precision = true_positives / (true_positives + false_positives)
    recall = true_positives / (true_positives + false_negatives)
    f1 = 2 * precision * recall / (precision + recall)

    return periodicity_loss, pitch_loss, f1
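

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module).
# Assumes mono waveforms shaped (1, samples) at torchcrepe.SAMPLE_RATE
# (16 kHz), which is what torchcrepe.predict expects. The synthetic sine
# signals below are placeholders for real reference/generated audio.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # One second of a 220 Hz sine as the "reference" and a slightly
    # detuned copy as the "prediction" -- purely synthetic test data.
    sr = torchcrepe.SAMPLE_RATE
    t = torch.arange(sr, dtype=torch.float32) / sr
    y = torch.sin(2 * torch.pi * 220.0 * t).unsqueeze(0)
    y_hat = torch.sin(2 * torch.pi * 222.0 * t).unsqueeze(0)

    periodicity_loss, pitch_loss, f1 = calculate_periodicity_metrics(y, y_hat)
    print(f"periodicity RMSE:   {periodicity_loss:.4f}")
    print(f"pitch RMSE (cents): {pitch_loss:.2f}")
    print(f"voiced/unvoiced F1: {f1:.3f}")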