|  |  | 
					
						
						|  |  | 
					
						
						|  | import os | 
					
						
						|  | import subprocess | 
					
						
						|  | from functools import lru_cache | 
					
						
						|  | from typing import Optional, Union | 
					
						
						|  | from scipy.io.wavfile import write | 
					
						
						|  | import tempfile | 
					
						
						|  |  | 
					
						
						|  | import numpy as np | 
					
						
						|  | import torch | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  |  | 
					
						
						|  | def exact_div(x, y): | 
					
						
						|  | assert x % y == 0 | 
					
						
						|  | return x // y | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | SAMPLE_RATE = 16000 | 
					
						
						|  | N_FFT = 400 | 
					
						
						|  | HOP_LENGTH = 160 | 
					
						
						|  | CHUNK_LENGTH = 30 | 
					
						
						|  | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE | 
					
						
						|  | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) | 
					
						
						|  |  | 
					
						
						|  | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 | 
					
						
						|  | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) | 
					
						
						|  | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray: | 
					
						
						|  | """ | 
					
						
						|  | Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary. | 
					
						
						|  |  | 
					
						
						|  | Parameters | 
					
						
						|  | ---------- | 
					
						
						|  | file: Union[str, np.ndarray] | 
					
						
						|  | The audio file to open or a numpy array containing the audio data. | 
					
						
						|  |  | 
					
						
						|  | sr: int | 
					
						
						|  | The sample rate to resample the audio if necessary. | 
					
						
						|  |  | 
					
						
						|  | Returns | 
					
						
						|  | ------- | 
					
						
						|  | A NumPy array containing the audio waveform, in float32 dtype. | 
					
						
						|  | """ | 
					
						
						|  | if isinstance(file, np.ndarray): | 
					
						
						|  | if file.dtype != np.float32: | 
					
						
						|  | file = file.astype(np.float32) | 
					
						
						|  | if file.ndim > 1: | 
					
						
						|  | file = np.mean(file, axis=1) | 
					
						
						|  |  | 
					
						
						|  | temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") | 
					
						
						|  | write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16)) | 
					
						
						|  | temp_file_path = temp_file.name | 
					
						
						|  | temp_file.close() | 
					
						
						|  | else: | 
					
						
						|  | temp_file_path = file | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | cmd = [ | 
					
						
						|  | "ffmpeg", | 
					
						
						|  | "-nostdin", | 
					
						
						|  | "-threads", | 
					
						
						|  | "0", | 
					
						
						|  | "-i", | 
					
						
						|  | temp_file_path, | 
					
						
						|  | "-f", | 
					
						
						|  | "s16le", | 
					
						
						|  | "-ac", | 
					
						
						|  | "1", | 
					
						
						|  | "-acodec", | 
					
						
						|  | "pcm_s16le", | 
					
						
						|  | "-ar", | 
					
						
						|  | str(sr), | 
					
						
						|  | "-", | 
					
						
						|  | ] | 
					
						
						|  | out = subprocess.run(cmd, capture_output=True, check=True).stdout | 
					
						
						|  | except subprocess.CalledProcessError as e: | 
					
						
						|  | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e | 
					
						
						|  | finally: | 
					
						
						|  | if isinstance(file, np.ndarray): | 
					
						
						|  | os.remove(temp_file_path) | 
					
						
						|  |  | 
					
						
						|  | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): | 
					
						
						|  | """ | 
					
						
						|  | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. | 
					
						
						|  | """ | 
					
						
						|  | if torch.is_tensor(array): | 
					
						
						|  | if array.shape[axis] > length: | 
					
						
						|  | array = array.index_select( | 
					
						
						|  | dim=axis, index=torch.arange(length, device=array.device) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if array.shape[axis] < length: | 
					
						
						|  | pad_widths = [(0, 0)] * array.ndim | 
					
						
						|  | pad_widths[axis] = (0, length - array.shape[axis]) | 
					
						
						|  | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) | 
					
						
						|  | else: | 
					
						
						|  | if array.shape[axis] > length: | 
					
						
						|  | array = array.take(indices=range(length), axis=axis) | 
					
						
						|  |  | 
					
						
						|  | if array.shape[axis] < length: | 
					
						
						|  | pad_widths = [(0, 0)] * array.ndim | 
					
						
						|  | pad_widths[axis] = (0, length - array.shape[axis]) | 
					
						
						|  | array = np.pad(array, pad_widths) | 
					
						
						|  |  | 
					
						
						|  | return array | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @lru_cache(maxsize=None) | 
					
						
						|  | def mel_filters(device, n_mels: int) -> torch.Tensor: | 
					
						
						|  | """ | 
					
						
						|  | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. | 
					
						
						|  | Allows decoupling librosa dependency; saved using: | 
					
						
						|  |  | 
					
						
						|  | np.savez_compressed( | 
					
						
						|  | "mel_filters.npz", | 
					
						
						|  | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), | 
					
						
						|  | ) | 
					
						
						|  | """ | 
					
						
						|  | assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}" | 
					
						
						|  | with np.load( | 
					
						
						|  | os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") | 
					
						
						|  | ) as f: | 
					
						
						|  | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def log_mel_spectrogram( | 
					
						
						|  | audio: Union[str, np.ndarray, torch.Tensor], | 
					
						
						|  | n_mels: int, | 
					
						
						|  | padding: int = 0, | 
					
						
						|  | device: Optional[Union[str, torch.device]] = None, | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Compute the log-Mel spectrogram of | 
					
						
						|  |  | 
					
						
						|  | Parameters | 
					
						
						|  | ---------- | 
					
						
						|  | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) | 
					
						
						|  | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz | 
					
						
						|  |  | 
					
						
						|  | n_mels: int | 
					
						
						|  | The number of Mel-frequency filters, only 80 is supported | 
					
						
						|  |  | 
					
						
						|  | padding: int | 
					
						
						|  | Number of zero samples to pad to the right | 
					
						
						|  |  | 
					
						
						|  | device: Optional[Union[str, torch.device]] | 
					
						
						|  | If given, the audio tensor is moved to this device before STFT | 
					
						
						|  |  | 
					
						
						|  | Returns | 
					
						
						|  | ------- | 
					
						
						|  | torch.Tensor, shape = (80, n_frames) | 
					
						
						|  | A Tensor that contains the Mel spectrogram | 
					
						
						|  | """ | 
					
						
						|  | if not torch.is_tensor(audio): | 
					
						
						|  | if isinstance(audio, str): | 
					
						
						|  | audio = load_audio(audio) | 
					
						
						|  | audio = torch.from_numpy(audio) | 
					
						
						|  |  | 
					
						
						|  | if device is not None: | 
					
						
						|  | audio = audio.to(device) | 
					
						
						|  | if padding > 0: | 
					
						
						|  | audio = F.pad(audio, (0, padding)) | 
					
						
						|  | window = torch.hann_window(N_FFT).to(audio.device) | 
					
						
						|  | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) | 
					
						
						|  | magnitudes = stft[..., :-1].abs() ** 2 | 
					
						
						|  |  | 
					
						
						|  | filters = mel_filters(audio.device, n_mels) | 
					
						
						|  | mel_spec = filters @ magnitudes | 
					
						
						|  |  | 
					
						
						|  | log_spec = torch.clamp(mel_spec, min=1e-10).log10() | 
					
						
						|  | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) | 
					
						
						|  | log_spec = (log_spec + 4.0) / 4.0 | 
					
						
						|  | return log_spec |