import librosa import numpy as np import torch import crepe ############################################################################### # Probability sequence decoding methods ############################################################################### def argmax(logits): """Sample observations by taking the argmax""" bins = logits.argmax(dim=1) # Convert to frequency in Hz return bins, crepe.convert.bins_to_frequency(bins) def weighted_argmax(logits): """Sample observations using weighted sum near the argmax""" # Find center of analysis window bins = logits.argmax(dim=1) # Find bounds of analysis window start = torch.max(torch.tensor(0, device=logits.device), bins - 4) end = torch.min(torch.tensor(logits.size(1), device=logits.device), bins + 5) # Mask out everything outside of window for batch in range(logits.size(0)): for time in range(logits.size(2)): logits[batch, :start[batch, time], time] = -float('inf') logits[batch, end[batch, time]:, time] = -float('inf') # Construct weights if not hasattr(weighted_argmax, 'weights'): weights = crepe.convert.bins_to_cents(torch.arange(360)) weighted_argmax.weights = weights[None, :, None] # Ensure devices are the same (no-op if they are) weighted_argmax.weights = weighted_argmax.weights.to(logits.device) # Convert to probabilities with torch.no_grad(): probs = torch.sigmoid(logits) # Apply weights cents = (weighted_argmax.weights * probs).sum(dim=1) / probs.sum(dim=1) # Convert to frequency in Hz return bins, crepe.convert.cents_to_frequency(cents) def viterbi(logits): """Sample observations using viterbi decoding""" # Create viterbi transition matrix if not hasattr(viterbi, 'transition'): xx, yy = np.meshgrid(range(360), range(360)) transition = np.maximum(12 - abs(xx - yy), 0) transition = transition / transition.sum(axis=1, keepdims=True) viterbi.transition = transition # Normalize logits with torch.no_grad(): probs = torch.nn.functional.softmax(logits, dim=1) # Convert to numpy sequences = probs.cpu().numpy() # Perform viterbi decoding bins = np.array([ librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64) for sequence in sequences]) # Convert to pytorch bins = torch.tensor(bins, device=probs.device) # Convert to frequency in Hz return bins, crepe.convert.bins_to_frequency(bins)