# audio-model/test.py — demo: load an MP3, compute a mel spectrogram, and
# reconstruct it with a pre-trained masked autoencoder (MAE).
# Uploaded by Tharya ("Upload 3 files", commit 9442c34, verified).
from typing import Union, Tuple
import numpy as np
from numpy.typing import NDArray
import torch
from torch import nn
from functools import partial
import matplotlib.pyplot as plt
from PIL import Image
import librosa
import miniaudio
from mae import MaskedAutoencoderViT
def load_audio(
        path: str,
        sr: int = 32000,
        duration: int = 20,
) -> np.ndarray:
    """Load the first `duration` seconds of an audio file as a mono float32 signal.

    Args:
        path: Path to the audio file (any format miniaudio can decode).
        sr: Target sample rate in Hz; the stream is resampled to this rate.
        duration: Number of seconds to read from the start of the file.

    Returns:
        1-D float32 numpy array with at most ``sr * duration`` samples
        (fewer if the file is shorter).
    """
    # The stream decodes lazily; a single next() yields the first
    # sr * duration frames as one buffer.
    g = miniaudio.stream_file(path, output_format=miniaudio.SampleFormat.FLOAT32, nchannels=1,
                              sample_rate=sr, frames_to_read=sr * duration)
    signal = np.array(next(g))
    return signal
def mel_spectrogram(
        signal: np.ndarray,
        sr: int = 32000,
        n_fft: int = 800,
        hop_length: int = 320,
        n_mels: int = 128,
) -> np.ndarray:
    """Compute a dB-scaled mel spectrogram, time-major.

    Args:
        signal: 1-D audio signal.
        sr: Sample rate of `signal` in Hz.
        n_fft: FFT window size in samples.
        hop_length: Hop between successive analysis frames, in samples.
        n_mels: Number of mel frequency bands.

    Returns:
        Array of shape (time, freq) holding the log-power mel spectrogram.
    """
    power = librosa.feature.melspectrogram(
        y=signal,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        window='hann',
        pad_mode='constant',
    )
    # librosa lays the result out as (freq, time); transpose so each
    # row is one time frame.
    return librosa.power_to_db(power).T
def display_image(
        img: Union[NDArray, Image.Image],
        figsize: Tuple[float, float] = (5, 5),
) -> None:
    """Show a 2-D array or PIL image with a colorbar and no axis ticks.

    Args:
        img: Image data to display; `origin='lower'` puts row 0 at the
            bottom (natural for spectrograms), `aspect='auto'` stretches
            to fill the figure.
        figsize: Matplotlib figure size in inches (width, height).
    """
    plt.figure(figsize=figsize)
    plt.imshow(img, origin='lower', aspect='auto')  # cmap options: 'viridis', 'coolwarm'
    plt.axis('off')
    plt.colorbar()
    plt.tight_layout()
    plt.show()
def normalize(arr: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardize `arr` to zero mean and (approximately) unit variance.

    `eps` keeps the division finite when the input is constant
    (std == 0), in which case the result is (numerically) all zeros.
    """
    mu = arr.mean()
    sigma = arr.std()
    return (arr - mu) / (sigma + eps)
if __name__ == '__main__':
    mp3_file = "/Users/chenjing22/Downloads/songs/See You Again.mp3"
    mel_spec = mel_spectrogram(load_audio(mp3_file, duration=21))  # (time, freq)

    # Pad with the spectrogram's floor value, or truncate, to a fixed
    # number of time frames expected by the model.
    length = mel_spec.shape[0]
    target_length = 2048
    mel_spec = mel_spec[:target_length] if length > target_length else np.pad(
        mel_spec, ((0, target_length - length), (0, 0)), mode='constant', constant_values=mel_spec.min()
    )

    # Per-example standardization (zero mean, unit variance).
    mel_spec = normalize(mel_spec)  # (2048, 128)
    display_image(mel_spec.T, figsize=(10, 4))

    # Model
    mae = MaskedAutoencoderViT(
        img_size=(2048, 128),
        patch_size=16,
        in_chans=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_mode=1,
        no_shift=False,
        decoder_embed_dim=512,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
        pos_trainable=False,
    )

    # Load pre-trained weights. weights_only=True restricts unpickling to
    # tensors/containers, avoiding arbitrary code execution from an
    # untrusted checkpoint (state_dicts need nothing more).
    ckpt_path = 'music-mae-32kHz.pth'
    mae.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
    device = 'cpu'  # 'cuda'
    mae.to(device)
    mae.eval()  # inference mode: disable dropout/other train-time stochasticity

    x = torch.from_numpy(mel_spec).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, 2048, 128)
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        mse_loss, y, mask = mae(x, mask_ratio=0.7)  # y: (1, 1024, 256), mask: (1, 1024)
        # Splice the known (unmasked, mask == 0) patches back in so only
        # masked regions show the model's reconstruction.
        y[mask == 0.] = mae.patchify(x)[mask == 0.]
        x_reconstructed = mae.unpatchify(y).squeeze(0).squeeze(0).numpy()
    print(f'mse_loss: {mse_loss.item()}')
    display_image(x_reconstructed.T, figsize=(10, 4))