|
from typing import Union, Tuple |
|
import numpy as np |
|
from numpy.typing import NDArray |
|
import torch |
|
from torch import nn |
|
from functools import partial |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import librosa |
|
import miniaudio |
|
|
|
from mae import MaskedAutoencoderViT |
|
|
|
|
|
def load_audio(
    path: str,
    sr: int = 32000,
    duration: int = 20,
) -> np.ndarray:
    """Decode the beginning of an audio file into a mono float32 signal.

    Args:
        path: Location of the audio file handed to ``miniaudio.stream_file``.
        sr: Sample rate requested from the decoder, in Hz.
        duration: Number of seconds to read; ``sr * duration`` frames are
            requested in a single chunk.

    Returns:
        A 1-D array built from the first chunk the decoder yields. Only the
        signal is returned (no sample rate), since the caller already knows
        ``sr``.
    """
    stream = miniaudio.stream_file(
        path,
        output_format=miniaudio.SampleFormat.FLOAT32,
        nchannels=1,
        sample_rate=sr,
        frames_to_read=sr * duration,
    )
    try:
        # Only the first chunk is needed; it holds the requested frames.
        first_chunk = next(stream)
    finally:
        # Close the generator explicitly so the decoder is released right
        # away instead of whenever the generator gets garbage-collected.
        stream.close()
    return np.array(first_chunk)
|
|
|
|
|
def mel_spectrogram(
    signal: np.ndarray,
    sr: int = 32000,
    n_fft: int = 800,
    hop_length: int = 320,
    n_mels: int = 128,
) -> np.ndarray:
    """Compute a dB-scaled mel spectrogram and return it transposed.

    Args:
        signal: Audio samples passed to librosa as ``y``.
        sr: Sample rate of ``signal`` in Hz.
        n_fft: FFT window size.
        hop_length: Number of samples between successive frames.
        n_mels: Number of mel bands.

    Returns:
        The transpose of librosa's dB-scaled mel spectrogram. The caller in
        this file treats axis 0 of the result as the frame axis.
    """
    stft_options = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        window='hann',
        pad_mode='constant',
    )
    power = librosa.feature.melspectrogram(
        y=signal, sr=sr, n_mels=n_mels, **stft_options
    )
    decibels = librosa.power_to_db(power)
    return np.transpose(decibels)
|
|
|
|
|
def display_image(
    img: Union[NDArray, Image.Image],
    figsize: Tuple[float, float] = (5, 5),
) -> None:
    """Show ``img`` in a new matplotlib figure with a colorbar and no axes.

    Args:
        img: Array or PIL image to render with ``plt.imshow``.
        figsize: Figure size forwarded to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    # origin='lower' puts row 0 at the bottom; aspect='auto' lets the image
    # stretch to fill the figure instead of keeping square pixels.
    rendered = plt.imshow(img, origin='lower', aspect='auto')
    plt.axis('off')
    plt.colorbar(rendered)
    plt.tight_layout()
    plt.show()
|
|
|
|
|
def normalize(arr: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardize ``arr`` using its global mean and standard deviation.

    Args:
        arr: Input array; statistics are taken over all elements.
        eps: Added to the standard deviation to avoid division by zero.

    Returns:
        ``(arr - mean) / (std + eps)`` with the same shape as ``arr``.
    """
    centered = arr - arr.mean()
    scale = arr.std() + eps
    return centered / scale
|
|
|
|
|
def main(
    mp3_file: str = "/Users/chenjing22/Downloads/songs/See You Again.mp3",
    ckpt_path: str = 'music-mae-32kHz.pth',
    mask_ratio: float = 0.7,
    device: str = 'cpu',
) -> None:
    """Reconstruct a masked mel spectrogram with a pretrained MAE and plot it.

    Loads the audio file, converts it to a normalized mel spectrogram of
    exactly 2048 frames, runs it through ``MaskedAutoencoderViT`` with random
    masking, and displays both the input and the reconstruction.

    Args:
        mp3_file: Audio file to analyse.
        ckpt_path: Checkpoint containing the model ``state_dict``.
        mask_ratio: Masking ratio forwarded to the model's forward call.
        device: Torch device the model and input are moved to.
    """
    # With the default sr=32000 and hop_length=320, 21 s yields roughly 2100
    # frames, i.e. a little more than the 2048 the model expects.
    mel_spec = mel_spectrogram(load_audio(mp3_file, duration=21))

    # Crop or pad along axis 0 to the fixed length the model was built for;
    # padding uses the spectrogram minimum so it blends in as silence.
    length = mel_spec.shape[0]
    target_length = 2048
    if length > target_length:
        mel_spec = mel_spec[:target_length]
    else:
        mel_spec = np.pad(
            mel_spec,
            ((0, target_length - length), (0, 0)),
            mode='constant',
            constant_values=mel_spec.min(),
        )

    mel_spec = normalize(mel_spec)

    display_image(mel_spec.T, figsize=(10, 4))

    mae = MaskedAutoencoderViT(
        img_size=(2048, 128),
        patch_size=16,
        in_chans=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_mode=1,
        no_shift=False,
        decoder_embed_dim=512,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
        pos_trainable=False,
    )

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    mae.to(device)
    # Inference only: switch off train-time behaviour of any stochastic layers.
    mae.eval()

    # Two unsqueeze calls add leading singleton axes (batch and channel,
    # matching in_chans=1).
    x = torch.from_numpy(mel_spec).unsqueeze(0).unsqueeze(0).to(device)

    # No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        mse_loss, y, mask = mae(x, mask_ratio=mask_ratio)

        # Where mask == 0, overwrite the prediction with the original
        # patches. Presumably 0 marks patches the encoder saw (MAE
        # convention) -- confirm against mae.py.
        y[mask == 0.] = mae.patchify(x)[mask == 0.]
        # .cpu() keeps the numpy conversion valid for non-CPU devices.
        x_reconstructed = mae.unpatchify(y).squeeze(0).squeeze(0).cpu().numpy()

    print(f'mse_loss: {mse_loss.item()}')
    display_image(x_reconstructed.T, figsize=(10, 4))


if __name__ == '__main__':
    main()
|
|