# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.

import os
import sys

# To import modules from the parent directory
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import argparse
import json
from time import time

import numpy as np
import torch
from scipy.io.wavfile import write
from tqdm import tqdm

from bigvgan import BigVGAN
from env import AttrDict
from meldataset import mel_spectrogram, MAX_WAV_VALUE

torch.backends.cudnn.benchmark = True

# For easier debugging
torch.set_printoptions(linewidth=200, threshold=10_000)
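

# Generate a frequency-modulated sine wave whose pitch sweeps between 220 Hz
# and 1760 Hz, used as a deterministic ground-truth waveform for the
# end-to-end inference test.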
def generate_soundwave(duration=5.0, sr=24000):
    t = np.linspace(0, duration, int(sr * duration), False, dtype=np.float32)

    modulation = np.sin(2 * np.pi * t / duration)

    min_freq = 220
    max_freq = 1760
    frequencies = min_freq + (max_freq - min_freq) * (modulation + 1) / 2
    soundwave = np.sin(2 * np.pi * frequencies * t)

    soundwave = soundwave / np.max(np.abs(soundwave)) * 0.95

    return soundwave, sr
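

# Wrapper around meldataset.mel_spectrogram that pulls the STFT/mel settings
# from the model config (AttrDict h).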
def get_mel(x, h):
    return mel_spectrogram(
        x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax
    )
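

# Load a checkpoint dict from disk onto the given device.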
def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
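

# This test loads the same BigVGAN checkpoint into a plain PyTorch model and a
# model using the fused CUDA kernel, checks that their outputs match, then
# benchmarks the inference speed and VRAM usage of both variants.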
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Test script to check CUDA kernel correctness."
    )
    parser.add_argument(
        "--checkpoint_file",
        type=str,
        required=True,
        help="Path to the checkpoint file. Assumes config.json exists in the directory.",
    )
    args = parser.parse_args()
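
    # Read the model config (config.json) stored next to the checkpoint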
    config_file = os.path.join(os.path.split(args.checkpoint_file)[0], "config.json")
    with open(config_file) as f:
        config = f.read()
    json_config = json.loads(config)
    h = AttrDict({**json_config})

    print("loading plain PyTorch BigVGAN")
    generator_original = BigVGAN(h).to("cuda")
    print("loading CUDA kernel BigVGAN with auto-build")
    generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda")

    state_dict_g = load_checkpoint(args.checkpoint_file, "cuda")
    generator_original.load_state_dict(state_dict_g["generator"])
    generator_cuda_kernel.load_state_dict(state_dict_g["generator"])
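
    # Remove weight norm and switch both models to eval mode for inference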
    generator_original.remove_weight_norm()
    generator_original.eval()
    generator_cuda_kernel.remove_weight_norm()
    generator_cuda_kernel.eval()

    # Define the number of samples and the mel length (in frames) to benchmark
    num_sample = 10
    num_mel_frame = 16384

    # CUDA kernel correctness check
    diff = 0.0
    for i in tqdm(range(num_sample)):
        # Random mel input
        data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")

        with torch.inference_mode():
            audio_original = generator_original(data)

        with torch.inference_mode():
            audio_cuda_kernel = generator_cuda_kernel(data)

        # Both outputs should be (almost) the same
        test_result = (audio_original - audio_cuda_kernel).abs()
        diff += test_result.mean(dim=-1).item()

    diff /= num_sample
    # A small difference (~1e-3) is expected and does not affect perceptual quality
    if diff <= 2e-3:
        print(
            f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}"
            f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test CUDA fused vs. plain torch BigVGAN inference"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}"
            f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
        )

    del data, audio_original, audio_cuda_kernel

    # Variables for tracking total time and VRAM usage
    toc_total_original = 0
    toc_total_cuda_kernel = 0
    vram_used_original_total = 0
    vram_used_cuda_kernel_total = 0
    audio_length_total = 0

    # Measure original inference in isolation
    for i in tqdm(range(num_sample)):
        torch.cuda.reset_peak_memory_stats(device="cuda")
        data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
        torch.cuda.synchronize()
        tic = time()
        with torch.inference_mode():
            audio_original = generator_original(data)
        torch.cuda.synchronize()
        toc = time() - tic
        toc_total_original += toc

        vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")
        del data, audio_original
        torch.cuda.empty_cache()

    # Measure CUDA kernel inference in isolation
    for i in tqdm(range(num_sample)):
        torch.cuda.reset_peak_memory_stats(device="cuda")
        data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
        torch.cuda.synchronize()
        tic = time()
        with torch.inference_mode():
            audio_cuda_kernel = generator_cuda_kernel(data)
        torch.cuda.synchronize()
        toc = time() - tic
        toc_total_cuda_kernel += toc
        audio_length_total += audio_cuda_kernel.shape[-1]

        vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")
        del data, audio_cuda_kernel
        torch.cuda.empty_cache()

    # Calculate metrics
    audio_second = audio_length_total / h.sampling_rate
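    # Throughput in kHz: audio samples generated per second of wall-clock time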
    khz_original = audio_length_total / toc_total_original / 1000
    khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
    vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
    vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)

    # Print results
    print(
        f"Original BigVGAN: took {toc_total_original:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_original:.1f}kHz, {audio_second / toc_total_original:.1f}x faster than realtime, VRAM used {vram_used_original_gb:.1f} GB"
    )
    print(
        f"CUDA kernel BigVGAN: took {toc_total_cuda_kernel:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_cuda_kernel:.1f}kHz, {audio_second / toc_total_cuda_kernel:.1f}x faster than realtime, VRAM used {vram_used_cuda_kernel_gb:.1f} GB"
    )
    print(f"Speedup of CUDA kernel: {khz_cuda_kernel / khz_original:.2f}x")
    print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb:.2f}x")

    # Use an artificial sine wave for the end-to-end inference test
    audio_real, sr = generate_soundwave(duration=5.0, sr=h.sampling_rate)
    audio_real = torch.tensor(audio_real).to("cuda")
    # Compute mel spectrogram from the ground truth audio
    x = get_mel(audio_real.unsqueeze(0), h)

    with torch.inference_mode():
        y_g_hat_original = generator_original(x)
        y_g_hat_cuda_kernel = generator_cuda_kernel(x)
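
    # Convert waveforms from float in [-1, 1] to 16-bit integer PCM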
    audio_real = audio_real.squeeze()
    audio_real = audio_real * MAX_WAV_VALUE
    audio_real = audio_real.cpu().numpy().astype("int16")

    audio_original = y_g_hat_original.squeeze()
    audio_original = audio_original * MAX_WAV_VALUE
    audio_original = audio_original.cpu().numpy().astype("int16")

    audio_cuda_kernel = y_g_hat_cuda_kernel.squeeze()
    audio_cuda_kernel = audio_cuda_kernel * MAX_WAV_VALUE
    audio_cuda_kernel = audio_cuda_kernel.cpu().numpy().astype("int16")
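
    # Write the reference and generated audio to tmp/ for manual listening checks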
    os.makedirs("tmp", exist_ok=True)
    output_file_real = os.path.join("tmp", "audio_real.wav")
    output_file_original = os.path.join("tmp", "audio_generated_original.wav")
    output_file_cuda_kernel = os.path.join("tmp", "audio_generated_cuda_kernel.wav")
    write(output_file_real, h.sampling_rate, audio_real)
    write(output_file_original, h.sampling_rate, audio_original)
    write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel)
    print("Example audio from the original and the fused CUDA kernel models written to tmp!")
    print("Done")