Spaces:
Running
Running
import librosa | |
import numpy as np | |
import torch | |
import crepe | |
###############################################################################
# Probability sequence decoding methods
###############################################################################
def argmax(logits):
    """Decode pitch by choosing the highest-scoring bin in every frame.

    Args:
        logits: Tensor of per-bin scores; the bin axis is dimension 1.

    Returns:
        Tuple ``(bins, frequency)`` where ``bins`` holds the winning bin
        index of each frame and ``frequency`` is that bin converted to Hz.
    """
    best_bins = torch.argmax(logits, dim=1)
    frequency_hz = crepe.convert.bins_to_frequency(best_bins)
    return best_bins, frequency_hz
def weighted_argmax(logits):
    """Sample observations using weighted sum near the argmax.

    For each frame, the argmax bin along dimension 1 is found. The pitch is
    then refined to the sigmoid-probability-weighted mean (in cents) of the
    bins in the half-open window ``[argmax - 4, argmax + 5)``, clipped to the
    valid bin range.

    Unlike earlier versions, the caller's ``logits`` tensor is not modified.

    Args:
        logits: Tensor indexed as ``(batch, bins, time)``.

    Returns:
        Tuple ``(bins, frequency)``: the argmax bin index per (batch, time)
        and the refined pitch in Hz.
    """
    num_bins = logits.size(1)

    # Find center of analysis window
    bins = logits.argmax(dim=1)

    # Build a (batch, bins, time) mask selecting bin indices in
    # [center - 4, center + 5). Indices outside [0, num_bins) never exist,
    # so no explicit clamping is needed.
    bin_index = torch.arange(num_bins, device=logits.device)[None, :, None]
    center = bins[:, None, :]
    in_window = (bin_index >= center - 4) & (bin_index < center + 5)

    # Mask out everything outside of window on a copy, so the caller's
    # tensor is left untouched
    windowed = logits.masked_fill(~in_window, -float('inf'))

    # Construct (and cache) per-bin weights in cents; rebuilt only if the
    # number of bins differs from the cached weights
    weights = getattr(weighted_argmax, 'weights', None)
    if weights is None or weights.size(1) != num_bins:
        weights = crepe.convert.bins_to_cents(torch.arange(num_bins))
        weights = weights[None, :, None]
    # Ensure devices are the same (no-op if they are)
    weights = weights.to(logits.device)
    weighted_argmax.weights = weights

    # Convert to probabilities; sigmoid(-inf) == 0 removes out-of-window bins
    with torch.no_grad():
        probs = torch.sigmoid(windowed)

    # Apply weights
    cents = (weights * probs).sum(dim=1) / probs.sum(dim=1)

    # Convert to frequency in Hz
    return bins, crepe.convert.cents_to_frequency(cents)
def viterbi(logits):
    """Sample observations using viterbi decoding.

    Softmax-normalizes ``logits`` over dimension 1 and runs
    ``librosa.sequence.viterbi`` on each batch item. The transition matrix
    is banded: moving from bin ``i`` to bin ``j`` has weight
    ``max(12 - |i - j|, 0)``, normalized per row. Jumps of 12 or more bins
    therefore get zero weight.

    Args:
        logits: Tensor of per-bin scores, assumed ``(batch, bins, time)``.
            Each ``logits[b]`` is passed to librosa as a ``(states, steps)``
            probability matrix.

    Returns:
        Tuple ``(bins, frequency)``. ``bins`` holds the decoded int64 bin
        indices, on the same device as ``logits``. ``frequency`` is the
        matching pitch in Hz.
    """
    num_bins = logits.size(1)

    # Create (and cache) the viterbi transition matrix; rebuilt only when
    # the number of bins changes
    transition = getattr(viterbi, 'transition', None)
    if transition is None or transition.shape != (num_bins, num_bins):
        xx, yy = np.meshgrid(range(num_bins), range(num_bins))
        transition = np.maximum(12 - abs(xx - yy), 0)
        transition = transition / transition.sum(axis=1, keepdims=True)
        viterbi.transition = transition

    # Normalize logits
    with torch.no_grad():
        probs = torch.nn.functional.softmax(logits, dim=1)

    # Convert to numpy
    sequences = probs.cpu().numpy()

    # Perform viterbi decoding independently for each batch item
    decoded = [
        librosa.sequence.viterbi(sequence, transition).astype(np.int64)
        for sequence in sequences]
    if decoded:
        bins = np.stack(decoded)
    else:
        # np.array([]) would yield float64 with shape (0,); keep int64 and
        # the per-item trailing shape for an empty batch
        bins = np.empty(
            sequences.shape[:1] + sequences.shape[2:], dtype=np.int64)

    # Convert to pytorch
    bins = torch.tensor(bins, device=probs.device)

    # Convert to frequency in Hz
    return bins, crepe.convert.bins_to_frequency(bins)