"""Periodicity metrics adapted from https://github.com/descriptinc/cargan"""

import librosa
import numpy as np
import torch
import torchaudio
import torchcrepe
from torchcrepe.loudness import REF_DB

SILENCE_THRESHOLD = -60  # dB; frames quieter than this are treated as silence
UNVOICED_THRESHOLD = 0.21  # periodicity below this is treated as unvoiced

def predict_pitch(
    audio: torch.Tensor,
    silence_threshold: float = SILENCE_THRESHOLD,
    unvoiced_threshold: float = UNVOICED_THRESHOLD,
):
    """
    Predicts pitch and periodicity for the given audio.

    Args:
        audio (Tensor): The audio waveform.
        silence_threshold (float): The threshold for silence detection.
        unvoiced_threshold (float): The threshold for unvoiced detection.

    Returns:
        pitch (ndarray): The predicted pitch.
        periodicity (ndarray): The predicted periodicity.
    """
    # torchcrepe inference
    pitch, periodicity = torchcrepe.predict(
        audio,
        fmin=50.0,
        fmax=550.0,
        sample_rate=torchcrepe.SAMPLE_RATE,
        model="full",
        return_periodicity=True,
        device=audio.device,
        pad=False,
    )
    pitch = pitch.cpu().numpy()
    periodicity = periodicity.cpu().numpy()

    # Calculate the dB-scaled spectrogram and set low-energy frames to unvoiced
    hop_length = torchcrepe.SAMPLE_RATE // 100  # default CREPE hop (10 ms)
    stft = torchaudio.functional.spectrogram(
        audio,
        window=torch.hann_window(torchcrepe.WINDOW_SIZE, device=audio.device),
        n_fft=torchcrepe.WINDOW_SIZE,
        hop_length=hop_length,
        win_length=torchcrepe.WINDOW_SIZE,
        power=2,
        normalized=False,
        pad=0,
        center=False,
    )

    # Perceptual weighting: weight the power spectrogram by frequency, then
    # zero out the periodicity of frames whose mean loudness falls below the
    # silence threshold
    freqs = librosa.fft_frequencies(sr=torchcrepe.SAMPLE_RATE, n_fft=torchcrepe.WINDOW_SIZE)
    perceptual_stft = librosa.perceptual_weighting(stft.cpu().numpy(), freqs) - REF_DB
    silence = perceptual_stft.mean(axis=1) < silence_threshold
    periodicity[silence] = 0

    # Mark frames with low periodicity as unvoiced (NaN pitch)
    pitch[periodicity < unvoiced_threshold] = torchcrepe.UNVOICED
    return pitch, periodicity
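
# Example usage (a minimal sketch; "speech.wav" is a hypothetical file, and the
# waveform is first resampled to torchcrepe's expected 16 kHz rate):
#
#     wav, sr = torchaudio.load("speech.wav")  # shape: (1, num_samples)
#     wav = torchaudio.functional.resample(wav, sr, torchcrepe.SAMPLE_RATE)
#     pitch, periodicity = predict_pitch(wav)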


def calculate_periodicity_metrics(y: torch.Tensor, y_hat: torch.Tensor):
    """
    Calculates periodicity metrics for predicted vs. ground-truth audio.

    Args:
        y (Tensor): The ground-truth audio data.
        y_hat (Tensor): The predicted audio data.

    Returns:
        periodicity_loss (float): The periodicity loss.
        pitch_loss (float): The pitch loss.
        f1 (float): The F1 score for voiced/unvoiced classification.
    """
    true_pitch, true_periodicity = predict_pitch(y)
    pred_pitch, pred_periodicity = predict_pitch(y_hat)

    # Unvoiced frames carry NaN pitch, so voicing masks follow from isnan
    true_voiced = ~np.isnan(true_pitch)
    pred_voiced = ~np.isnan(pred_pitch)

    # Periodicity RMSE across frames
    periodicity_loss = np.sqrt(((pred_periodicity - true_periodicity) ** 2).mean(axis=1)).mean()

    # Pitch RMSE in cents, over frames voiced in both signals
    voiced = true_voiced & pred_voiced
    difference_cents = 1200 * (np.log2(true_pitch[voiced]) - np.log2(pred_pitch[voiced]))
    pitch_loss = np.sqrt((difference_cents ** 2).mean())

    # Voiced/unvoiced classification: precision, recall, and F1
    true_positives = (true_voiced & pred_voiced).sum()
    false_positives = (~true_voiced & pred_voiced).sum()
    false_negatives = (true_voiced & ~pred_voiced).sum()
    precision = true_positives / (true_positives + false_positives)
    recall = true_positives / (true_positives + false_negatives)
    f1 = 2 * precision * recall / (precision + recall)

    return periodicity_loss, pitch_loss, f1
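

if __name__ == "__main__":
    # A minimal usage sketch (the file names are hypothetical): both signals
    # are assumed to be mono and of equal length, and both are resampled to
    # torchcrepe's expected 16 kHz rate before computing the metrics.
    y, sr = torchaudio.load("reference.wav")
    y_hat, sr_hat = torchaudio.load("generated.wav")
    y = torchaudio.functional.resample(y, sr, torchcrepe.SAMPLE_RATE)
    y_hat = torchaudio.functional.resample(y_hat, sr_hat, torchcrepe.SAMPLE_RATE)

    periodicity_loss, pitch_loss, f1 = calculate_periodicity_metrics(y, y_hat)
    print(f"Periodicity RMSE: {periodicity_loss:.4f}")
    print(f"Pitch RMSE (cents): {pitch_loss:.2f}")
    print(f"Voiced/unvoiced F1: {f1:.4f}")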