Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import librosa | |
| import torch | |
| import numpy as np | |
def _align_lengths(audio_ref, audio_deg, method):
    """Bring two 1-D waveforms to a common length.

    Args:
        audio_ref: ground-truth waveform (1-D numpy array).
        audio_deg: predicted waveform (1-D numpy array).
        method: "cut" truncates both to the shorter length; "dtw" resamples
            both along a dynamic-time-warping path.

    Returns:
        Tuple ``(audio_ref, audio_deg)`` with equal lengths. The inputs are
        returned unchanged when their lengths already match, regardless of
        ``method``.

    Raises:
        ValueError: if the lengths differ and ``method`` is not supported.
    """
    if len(audio_ref) == len(audio_deg):
        return audio_ref, audio_deg

    if method == "cut":
        length = min(len(audio_ref), len(audio_deg))
        return audio_ref[:length], audio_deg[:length]

    if method == "dtw":
        # DTW is run directly on the raw samples, not on frame-level features.
        _, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
        # Each row of wp is a (ref_index, deg_index) pair; gather both signals
        # along the path in one vectorized step.
        # NOTE(review): the path appears to be returned end-to-start, which
        # would leave both aligned signals time-reversed. Row order is kept as
        # returned so existing results do not change -- confirm against the
        # librosa documentation.
        return audio_ref[wp[:, 0]], audio_deg[wp[:, 1]]

    raise ValueError(
        f"Unsupported alignment method {method!r}; expected 'cut' or 'dtw'."
    )


def _encodec_distance(audio_ref, audio_deg):
    """Encodec-style multi-scale distance between two equal-length waveforms.

    For window lengths 2**5 .. 2**10 (hop = window / 4, n_fft fixed at 1024)
    the L1 and MSE distances between STFT magnitudes are summed, and the sum
    is averaged over the scales.

    Returns:
        A 0-dim torch tensor.
    """
    n_fft = 1024
    window_exponents = range(5, 11)
    l1_loss = torch.nn.L1Loss(reduction="mean")
    l2_loss = torch.nn.MSELoss(reduction="mean")

    total = 0
    for exponent in window_exponents:
        win_length = 2**exponent
        hop_length = win_length // 4
        mag_ref = torch.from_numpy(
            np.abs(
                librosa.stft(
                    y=audio_ref,
                    n_fft=n_fft,
                    hop_length=hop_length,
                    win_length=win_length,
                )
            )
        )
        mag_deg = torch.from_numpy(
            np.abs(
                librosa.stft(
                    y=audio_deg,
                    n_fft=n_fft,
                    hop_length=hop_length,
                    win_length=win_length,
                )
            )
        )
        total += l1_loss(mag_ref, mag_deg) + l2_loss(mag_ref, mag_deg)

    return total / len(window_exponents)


def _pwg_magnitude(audio, n_fft, hop_length, win_length):
    """Return the clamped STFT magnitude of ``audio``, transposed.

    NOTE(review): no window is passed, so torch.stft falls back to its default
    window -- confirm this is intended.
    """
    # return_complex=True replaces the deprecated return_complex=False form;
    # the real/imaginary parts, and hence the magnitudes, are the same.
    spec = torch.stft(audio, n_fft, hop_length, win_length, return_complex=True)
    power = spec.real**2 + spec.imag**2
    # Clamp before sqrt so the log taken by the caller stays finite.
    return torch.sqrt(torch.clamp(power, min=1e-7)).transpose(1, 0)


def _pwg_distance(audio_ref, audio_deg):
    """ParallelWaveGAN-style multi-resolution STFT distance.

    For each (n_fft, hop, win) resolution this computes the spectral
    convergence ``||ref - deg||_F / ||ref||_F`` and the L1 distance between
    log magnitudes; each term is averaged over the resolutions and the two
    averages are added.

    Returns:
        A Python float.
    """
    resolutions = ((1024, 120, 600), (2048, 240, 1200), (512, 50, 240))
    l1_loss = torch.nn.L1Loss(reduction="mean")

    wav_ref = torch.from_numpy(audio_ref)
    wav_deg = torch.from_numpy(audio_deg)

    sc_total = 0
    mag_total = 0
    for n_fft, hop_length, win_length in resolutions:
        mag_ref = _pwg_magnitude(wav_ref, n_fft, hop_length, win_length)
        mag_deg = _pwg_magnitude(wav_deg, n_fft, hop_length, win_length)

        sc_total += torch.norm(mag_ref - mag_deg, p="fro") / torch.norm(
            mag_ref, p="fro"
        )
        mag_total += l1_loss(torch.log(mag_ref), torch.log(mag_deg))

    sc_mean = sc_total / len(resolutions)
    mag_mean = mag_total / len(resolutions)
    return sc_mean.item() + mag_mean.item()


def extract_mstft(
    audio_ref,
    audio_deg,
    fs=None,
    mid_freq=None,
    high_freq=None,
    method="cut",
    version="pwg",
):
    """Compute Multi-Scale STFT Distance (mstft) between the predicted and the ground truth audio.

    Args:
        audio_ref: path to the ground truth audio.
        audio_deg: path to the predicted audio.
        fs: sampling rate passed to ``librosa.load``. When None, both files
            are loaded with ``librosa.load``'s default ``sr``.
        mid_freq: division frequency for mid frequency parts. Accepted for
            backward compatibility only: band-limited distances are not part
            of the return value, so this argument does not affect the result.
        high_freq: division frequency for high frequency parts. Accepted for
            backward compatibility only, like ``mid_freq``.
        method: "dtw" will use dtw algorithm to align the length of the ground
            truth and predicted audio. "cut" will cut both audios into a same
            length according to the one with the shorter length. Only
            consulted when the two loaded signals differ in length.
        version: "pwg" will use the computational version provided by
            ParallelWaveGAN. "encodec" will use the computational version
            provided by Encodec.

    Returns:
        For ``version="encodec"``: a 0-dim torch tensor holding the mean over
        scales of (L1 + MSE) between STFT magnitudes.
        For ``version="pwg"``: a Python float, the mean spectral convergence
        plus the mean log-magnitude L1 distance.

    Raises:
        ValueError: if ``version`` is not supported, or if the signals differ
            in length and ``method`` is not supported.
    """
    # Fail before any file I/O instead of silently returning None later.
    if version not in ("encodec", "pwg"):
        raise ValueError(
            f"Unsupported version {version!r}; expected 'pwg' or 'encodec'."
        )

    # Load audio. The two branches are kept separate on purpose: passing
    # sr=None explicitly is not the same call as omitting sr.
    if fs is not None:
        audio_ref, _ = librosa.load(audio_ref, sr=fs)
        audio_deg, _ = librosa.load(audio_deg, sr=fs)
    else:
        audio_ref, _ = librosa.load(audio_ref)
        audio_deg, _ = librosa.load(audio_deg)

    # Audio length alignment
    audio_ref, audio_deg = _align_lengths(audio_ref, audio_deg, method)

    # Compute distance
    if version == "encodec":
        return _encodec_distance(audio_ref, audio_deg)
    return _pwg_distance(audio_ref, audio_deg)