"""Reconstruct a music mel-spectrogram with a pre-trained masked autoencoder.

Run as a script: ``python <this file> [audio_path]``. It loads the audio,
converts it to a normalized (2048, 128) log-mel spectrogram, masks 70% of its
patches with the MAE, prints the model loss, and displays the input and the
reconstruction as two figures.
"""
import sys
from functools import partial
from typing import Optional, Tuple, Union

import librosa
import matplotlib.pyplot as plt
import miniaudio
import numpy as np
import torch
from numpy.typing import NDArray
from PIL import Image
from torch import nn

from mae import MaskedAutoencoderViT

# Defaults for the demo in ``main``; the audio path can be overridden via argv.
DEFAULT_AUDIO_PATH = "/Users/chenjing22/Downloads/songs/See You Again.mp3"
CKPT_PATH = 'music-mae-32kHz.pth'
TARGET_LENGTH = 2048  # time frames expected by the model (img_size[0] below)


def load_audio(
    path: str,
    sr: int = 32000,
    duration: float = 20,
) -> np.ndarray:
    """Decode the beginning of ``path`` as a mono float32 signal.

    Args:
        path: Audio file readable by miniaudio.
        sr: Sample rate (Hz) the audio is decoded/converted to.
        duration: Seconds of audio to request from the start of the file;
            ``int(sr * duration)`` frames are read.

    Returns:
        1-D array holding the first decoded chunk (at most
        ``int(sr * duration)`` samples; fewer if the file is shorter).
    """
    frames = int(sr * duration)  # miniaudio needs an integer frame count
    stream = miniaudio.stream_file(
        path,
        output_format=miniaudio.SampleFormat.FLOAT32,
        nchannels=1,
        sample_rate=sr,
        frames_to_read=frames,
    )
    try:
        # Only the first chunk (up to ``frames`` frames) is needed.
        signal = np.array(next(stream))
    finally:
        # Close the generator so miniaudio can release its decoder now instead
        # of whenever the generator happens to be garbage-collected.
        stream.close()
    return signal


def mel_spectrogram(
    signal: np.ndarray,
    sr: int = 32000,
    n_fft: int = 800,
    hop_length: int = 320,
    n_mels: int = 128,
) -> np.ndarray:
    """Return the log-power (dB) mel spectrogram of ``signal`` as (time, freq).

    Args:
        signal: 1-D audio signal.
        sr: Sample rate of ``signal`` in Hz.
        n_fft: FFT window length in samples.
        hop_length: Samples between successive frames.
        n_mels: Number of mel bands (columns of the result).
    """
    mel_spec = librosa.feature.melspectrogram(
        y=signal, sr=sr, n_fft=n_fft, hop_length=hop_length,
        n_mels=n_mels, window='hann', pad_mode='constant',
    )
    mel_spec = librosa.power_to_db(mel_spec)  # (freq, time)
    return mel_spec.T  # (time, freq)


def display_image(
    img: Union[NDArray, Image.Image],
    figsize: Tuple[float, float] = (5, 5),
    cmap: Optional[str] = None,
) -> None:
    """Show ``img`` in a new figure with a colorbar and no axes.

    ``origin='lower'`` draws row 0 at the bottom, so a (freq, time) array is
    shown with low frequencies at the bottom. ``cmap`` (e.g. 'viridis',
    'coolwarm') defaults to matplotlib's configured colormap.
    """
    plt.figure(figsize=figsize)
    plt.imshow(img, origin='lower', aspect='auto', cmap=cmap)
    plt.axis('off')
    plt.colorbar()
    plt.tight_layout()
    plt.show()


def normalize(arr: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardize ``arr`` to zero mean and unit std; ``eps`` avoids /0."""
    return (arr - arr.mean()) / (arr.std() + eps)


def pad_or_truncate(mel_spec: np.ndarray, target_length: int = TARGET_LENGTH) -> np.ndarray:
    """Return ``mel_spec`` with exactly ``target_length`` rows (time frames).

    Longer inputs are cut at the end. Shorter ones are padded at the end with
    the spectrogram's minimum value, so the padding matches the quietest level
    already present.
    """
    length = mel_spec.shape[0]
    if length >= target_length:
        return mel_spec[:target_length]
    return np.pad(
        mel_spec,
        ((0, target_length - length), (0, 0)),
        mode='constant',
        constant_values=mel_spec.min(),
    )


def main(
    mp3_file: str = DEFAULT_AUDIO_PATH,
    ckpt_path: str = CKPT_PATH,
    device: str = 'cpu',  # or 'cuda'
) -> None:
    """Show the spectrogram of ``mp3_file`` and its MAE reconstruction."""
    mel_spec = mel_spectrogram(load_audio(mp3_file, duration=21))  # (time, freq)
    mel_spec = normalize(pad_or_truncate(mel_spec, TARGET_LENGTH))  # (2048, 128)
    display_image(mel_spec.T, figsize=(10, 4))

    # Model
    mae = MaskedAutoencoderViT(
        img_size=(TARGET_LENGTH, 128), patch_size=16, in_chans=1,
        embed_dim=768, depth=12, num_heads=12,
        decoder_mode=1, no_shift=False, decoder_embed_dim=512,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False, pos_trainable=False,
    )
    # Load pre-trained weights (loaded directly as a state_dict).
    mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    mae.to(device)
    mae.eval()  # inference only: disable training-mode behavior in submodules

    x = torch.from_numpy(mel_spec).float().unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, 2048, 128)
    # No autograd graph is needed, which also makes the in-place write into y safe.
    with torch.no_grad():
        mse_loss, y, mask = mae(x, mask_ratio=0.7)  # y: (1, 1024, 256), mask: (1, 1024)
        # Where mask == 0, replace the prediction with the true input patches
        # (presumably the visible, unmasked patches in MAE convention).
        visible = mask == 0.
        y[visible] = mae.patchify(x)[visible]
        x_reconstructed = mae.unpatchify(y).squeeze(0).squeeze(0)
    # .cpu() so the conversion also works when device is 'cuda'.
    x_reconstructed = x_reconstructed.cpu().numpy()

    print(f'mse_loss: {mse_loss.item()}')
    display_image(x_reconstructed.T, figsize=(10, 4))


if __name__ == '__main__':
    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_AUDIO_PATH)