Spaces:
Build error
Build error
| import math | |
| import torch | |
| import librosa | |
| # based on https://github.com/neuralaudio/hear-baseline/blob/main/hearbaseline/naive.py | |
| class RandomMelProjection(torch.nn.Module): | |
| def __init__( | |
| self, | |
| sample_rate, | |
| embed_dim=4096, | |
| n_mels=128, | |
| n_fft=4096, | |
| hop_size=1024, | |
| seed=0, | |
| epsilon=1e-4, | |
| ): | |
| super().__init__() | |
| self.sample_rate = sample_rate | |
| self.embed_dim = embed_dim | |
| self.n_mels = n_mels | |
| self.n_fft = n_fft | |
| self.hop_size = hop_size | |
| self.seed = seed | |
| self.epsilon = epsilon | |
| # Set random seed | |
| torch.random.manual_seed(self.seed) | |
| # Create a Hann window buffer to apply to frames prior to FFT. | |
| self.register_buffer("window", torch.hann_window(self.n_fft)) | |
| # Create a mel filter buffer. | |
| mel_scale = torch.tensor( | |
| librosa.filters.mel( | |
| self.sample_rate, | |
| n_fft=self.n_fft, | |
| n_mels=self.n_mels, | |
| ) | |
| ) | |
| self.register_buffer("mel_scale", mel_scale) | |
| # Projection matrices. | |
| normalization = math.sqrt(self.n_mels) | |
| self.projection = torch.nn.Parameter( | |
| torch.rand(self.n_mels, self.embed_dim) / normalization, | |
| requires_grad=False, | |
| ) | |
| def forward(self, x): | |
| bs, chs, samp = x.size() | |
| x = torch.stft( | |
| x.view(bs, -1), | |
| self.n_fft, | |
| self.hop_size, | |
| window=self.window, | |
| return_complex=True, | |
| ) | |
| x = x.unsqueeze(1).permute(0, 1, 3, 2) | |
| # Apply the mel-scale filter to the power spectrum. | |
| x = torch.matmul(x.abs(), self.mel_scale.transpose(0, 1)) | |
| # power scale | |
| x = torch.pow(x + self.epsilon, 0.3) | |
| # apply random projection | |
| e = x.matmul(self.projection) | |
| # take mean across temporal dim | |
| e = e.mean(dim=2).view(bs, -1) | |
| return e | |
| def compute_frame_embedding(self, x): | |
| # Compute the real-valued Fourier transform on windowed input signal. | |
| x = torch.fft.rfft(x * self.window) | |
| # Convert to a power spectrum. | |
| x = torch.abs(x) ** 2.0 | |
| # Apply the mel-scale filter to the power spectrum. | |
| x = torch.matmul(x, self.mel_scale.transpose(0, 1)) | |
| # Convert to a log mel spectrum. | |
| x = torch.log(x + self.epsilon) | |
| # Apply projection to get a 4096 dimension embedding | |
| embedding = x.matmul(self.projection) | |
| return embedding | |