Tacotron2_GST_eng / melgan /inference.py
AlexK-PL's picture
Upload MelGAN model
85e1f14
raw
history blame
1.64 kB
import os
import glob
import tqdm
import torch
import argparse
from scipy.io.wavfile import write
from model.generator import Generator
from utils.hparams import HParam, load_hparam_str
MAX_WAV_VALUE = 32768.0
def main(args):
checkpoint = torch.load(args.checkpoint_path)
if args.config is not None:
hp = HParam(args.config)
else:
hp = load_hparam_str(checkpoint['hp_str'])
model = Generator(hp.audio.n_mel_channels).cuda()
model.load_state_dict(checkpoint['model_g'])
model.eval(inference=False)
with torch.no_grad():
for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
mel = torch.load(melpath)
if len(mel.shape) == 2:
mel = mel.unsqueeze(0)
mel = mel.cuda()
audio = model.inference(mel)
audio = audio.cpu().detach().numpy()
out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
write(out_path, hp.audio.sampling_rate, audio)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default=None,
help="yaml file for config. will use hp_str from checkpoint if not given.")
parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
help="path of checkpoint pt file for evaluation")
parser.add_argument('-i', '--input_folder', type=str, required=True,
help="directory of mel-spectrograms to invert into raw audio. ")
args = parser.parse_args()
main(args)