import sys,os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import numpy as np import argparse import torch import random from tqdm import tqdm from whisper.model import Whisper, ModelDimensions from whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram def load_model(path) -> Whisper: device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint = torch.load(path, map_location="cpu") dims = ModelDimensions(**checkpoint["dims"]) print(dims) model = Whisper(dims) del model.decoder cut = len(model.encoder.blocks) // 4 cut = -1 * cut del model.encoder.blocks[cut:] model.load_state_dict(checkpoint["model_state_dict"], strict=False) model.eval() model.half() model.to(device) return model def pred_ppg(whisper: Whisper, wavPath, ppgPath): audio = load_audio(wavPath) audln = audio.shape[0] ppgln = audln // 320 audio = pad_or_trim(audio) mel = log_mel_spectrogram(audio).half().to(whisper.device) with torch.no_grad(): ppg = whisper.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() ppg = ppg[:ppgln,] # [length, dim=1280] # print(ppg.shape) np.save(ppgPath, ppg, allow_pickle=False) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True) parser.add_argument("-p", "--ppg", help="ppg", dest="ppg", required=True) args = parser.parse_args() print(args.wav) print(args.ppg) os.makedirs(args.ppg, exist_ok=True) wavPath = args.wav ppgPath = args.ppg whisper = load_model(os.path.join("whisper_pretrain", "large-v2.pt")) spkPaths = os.listdir(wavPath) random.shuffle(spkPaths) for spks in spkPaths: if os.path.isdir(f"./{wavPath}/{spks}"): os.makedirs(f"./{ppgPath}/{spks}", exist_ok=True) files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")] for file in tqdm(files, desc=f'Processing ppg {spks}'): if file.endswith(".wav"): # print(file) file = file[:-4] path_wav = f"{wavPath}/{spks}/{file}.wav" path_ppg = f"{ppgPath}/{spks}/{file}.ppg" if os.path.isfile(f"{path_ppg}.npy"): continue pred_ppg(whisper, path_wav, path_ppg)