AlexK-PL committed
Commit 85e1f14 · 1 Parent(s): 9bc975f

Upload MelGAN model


Repository cloned from:
https://github.com/seungwonpark/melgan/tree/master

melgan/.gitignore ADDED
@@ -0,0 +1,119 @@
+ # IDE configuration
+ .idea/
+
+ # configuration
+ config/*
+ !config/default.yaml
+ temp-restore.yaml
+
+ # logs, checkpoints
+ chkpt/
+ logs/
+
+ # just a temporary folder
+ temp/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
melgan/LICENSE ADDED
@@ -0,0 +1,29 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2019, Seungwon Park 박승원
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+    list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+    this list of conditions and the following disclaimer in the documentation
+    and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+    contributors may be used to endorse or promote products derived from
+    this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
melgan/README.md ADDED
@@ -0,0 +1,83 @@
+ # MelGAN
+ Unofficial PyTorch implementation of the [MelGAN vocoder](https://arxiv.org/abs/1910.06711)
+
+ ## Key Features
+
+ - MelGAN is lighter, faster, and better at generalizing to unseen speakers than [WaveGlow](https://github.com/NVIDIA/waveglow).
+ - This repository uses the identical mel-spectrogram function as [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2), so it can directly convert the output of NVIDIA's Tacotron 2 into raw audio.
+ - Pretrained model on LJSpeech-1.1 via [PyTorch Hub](https://pytorch.org/hub).
+
+ ![](./assets/gd.png)
+
+ ## Prerequisites
+
+ Tested on Python 3.6.
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Prepare Dataset
+
+ - Download a dataset for training. Any set of wav files with a 22050 Hz sample rate will do (e.g. LJSpeech, which was used in the paper).
+ - Preprocess: `python preprocess.py -c config/default.yaml -d [data's root path]`
+ - Edit the configuration `yaml` file.
+
+ ## Train & Tensorboard
+
+ - `python trainer.py -c [config yaml file] -n [name of the run]`
+   - `cp config/default.yaml config/config.yaml`, then edit `config.yaml`.
+   - Write the root paths of the train/validation data on the 2nd/3rd lines.
+   - Each path should contain pairs of `*.wav` files with corresponding (preprocessed) `*.mel` files.
+   - The data loader parses the list of files within the path recursively.
+ - `tensorboard --logdir logs/`
+
+ ## Pretrained model
+
+ Try with Google Colab: TODO
+
+ ```python
+ import torch
+ vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
+ vocoder.eval()
+ mel = torch.randn(1, 80, 234)  # use your own mel-spectrogram here
+
+ if torch.cuda.is_available():
+     vocoder = vocoder.cuda()
+     mel = mel.cuda()
+
+ with torch.no_grad():
+     audio = vocoder.inference(mel)
+ ```
+
+ ## Inference
+
+ - `python inference.py -p [checkpoint path] -i [input mel path]`
+
+ ## Results
+
+ See audio samples at http://swpark.me/melgan/.
+ The model was trained on a V100 GPU for 14 days using LJSpeech-1.1.
+
+ ![](./assets/lj-tensorboard-v0.3-alpha.png)
+
+
+ ## Implementation Authors
+
+ - [Seungwon Park](http://swpark.me) @ MINDsLab Inc. ([email protected], [email protected])
+ - Myunchul Joe @ MINDsLab Inc.
+ - [Rishikesh](https://github.com/rishikksh20) @ DeepSync Technologies Pvt Ltd.
+
+ ## License
+
+ BSD 3-Clause License.
+
+ - [utils/stft.py](./utils/stft.py) by Prem Seetharaman (BSD 3-Clause License)
+ - [datasets/mel2samp.py](./datasets/mel2samp.py) from https://github.com/NVIDIA/waveglow (BSD 3-Clause License)
+ - [utils/hparams.py](./utils/hparams.py) from https://github.com/HarryVolek/PyTorch_Speaker_Verification (no license specified)
+
+ ## Useful resources
+
+ - [How to Train a GAN? Tips and tricks to make GANs work](https://github.com/soumith/ganhacks) by Soumith Chintala
+ - [Official MelGAN implementation by original authors](https://github.com/descriptinc/melgan-neurips)
+ - [Reproduction of MelGAN - NeurIPS 2019 Reproducibility Challenge (Ablation Track)](https://openreview.net/pdf?id=9jTbNbBNw0) by Yifei Zhao, Yichao Yang, and Yang Gao
+   - "replacing the average pooling layer with max pooling layer and replacing reflection padding with replication padding improves the performance significantly, while combining them produces worse results"
melgan/assets/gd.png ADDED
melgan/assets/lj-tensorboard-v0.3-alpha.png ADDED
melgan/assets/lj-tensorboard.png ADDED
melgan/config/default.yaml ADDED
@@ -0,0 +1,34 @@
+ data: # root path of train/validation data (either relative or absolute path is ok)
+   train: ''
+   validation: ''
+ ---
+ train:
+   rep_discriminator: 1
+   num_workers: 32
+   batch_size: 16
+   optimizer: 'adam'
+   adam:
+     lr: 0.0001
+     beta1: 0.5
+     beta2: 0.9
+ ---
+ audio:
+   n_mel_channels: 80
+   segment_length: 16000
+   pad_short: 2000
+   filter_length: 1024
+   hop_length: 256 # WARNING: this can't be changed.
+   win_length: 1024
+   sampling_rate: 22050
+   mel_fmin: 0.0
+   mel_fmax: 8000.0
+ ---
+ model:
+   feat_match: 10.0
+ ---
+ log:
+   summary_interval: 1
+   validation_interval: 5
+   save_interval: 25
+   chkpt_dir: 'chkpt'
+   log_dir: 'logs'
melgan/datasets/dataloader.py ADDED
@@ -0,0 +1,67 @@
+ import os
+ import glob
+ import torch
+ import random
+ import numpy as np
+ from torch.utils.data import Dataset, DataLoader
+
+ from utils.utils import read_wav_np
+
+
+ def create_dataloader(hp, args, train):
+     dataset = MelFromDisk(hp, args, train)
+
+     if train:
+         return DataLoader(dataset=dataset, batch_size=hp.train.batch_size, shuffle=True,
+                           num_workers=hp.train.num_workers, pin_memory=True, drop_last=True)
+     else:
+         return DataLoader(dataset=dataset, batch_size=1, shuffle=False,
+                           num_workers=hp.train.num_workers, pin_memory=True, drop_last=False)
+
+
+ class MelFromDisk(Dataset):
+     def __init__(self, hp, args, train):
+         self.hp = hp
+         self.args = args
+         self.train = train
+         self.path = hp.data.train if train else hp.data.validation
+         self.wav_list = glob.glob(os.path.join(self.path, '**', '*.wav'), recursive=True)
+         self.mel_segment_length = hp.audio.segment_length // hp.audio.hop_length + 2
+         self.mapping = [i for i in range(len(self.wav_list))]
+
+     def __len__(self):
+         return len(self.wav_list)
+
+     def __getitem__(self, idx):
+         # in training, return two independent (mel, audio) pairs:
+         # one for the generator update and one for the discriminator update
+         if self.train:
+             idx1 = idx
+             idx2 = self.mapping[idx1]
+             return self.my_getitem(idx1), self.my_getitem(idx2)
+         else:
+             return self.my_getitem(idx)
+
+     def shuffle_mapping(self):
+         random.shuffle(self.mapping)
+
+     def my_getitem(self, idx):
+         wavpath = self.wav_list[idx]
+         melpath = wavpath.replace('.wav', '.mel')
+         sr, audio = read_wav_np(wavpath)
+         if len(audio) < self.hp.audio.segment_length + self.hp.audio.pad_short:
+             audio = np.pad(audio, (0, self.hp.audio.segment_length + self.hp.audio.pad_short - len(audio)),
+                            mode='constant', constant_values=0.0)
+
+         audio = torch.from_numpy(audio).unsqueeze(0)
+         mel = torch.load(melpath).squeeze(0)
+
+         if self.train:
+             max_mel_start = mel.size(1) - self.mel_segment_length
+             mel_start = random.randint(0, max_mel_start)
+             mel_end = mel_start + self.mel_segment_length
+             mel = mel[:, mel_start:mel_end]
+
+             audio_start = mel_start * self.hp.audio.hop_length
+             audio = audio[:, audio_start:audio_start + self.hp.audio.segment_length]
+
+         audio = audio + (1/32768) * torch.randn_like(audio)  # add ~1 LSB of dither
+         return mel, audio
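
A minimal sketch (assuming `config/default.yaml` has its data paths filled in and `preprocess.py` has been run) of the twin-pair batches this loader yields, which `utils/train.py` unpacks into separate generator and discriminator samples:

```python
from utils.hparams import HParam
from datasets.dataloader import create_dataloader

hp = HParam('config/default.yaml')
# args is only stored on the dataset, so None is fine for a quick test
trainloader = create_dataloader(hp, None, train=True)

(melG, audioG), (melD, audioD) = next(iter(trainloader))
print(melG.shape)    # (batch_size, n_mel_channels, mel segment frames)
print(audioG.shape)  # (batch_size, 1, segment_length)
```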
melgan/hubconf.py ADDED
@@ -0,0 +1,41 @@
+ dependencies = ['torch']
+ import torch
+ from model.generator import Generator
+
+ model_params = {
+     'nvidia_tacotron2_LJ11_epoch6400': {
+         'mel_channel': 80,
+         'model_url': 'https://github.com/seungwonpark/melgan/releases/download/v0.3-alpha/nvidia_tacotron2_LJ11_epoch6400.pt',
+     },
+ }
+
+
+ def melgan(model_name='nvidia_tacotron2_LJ11_epoch6400', pretrained=True, progress=True):
+     params = model_params[model_name]
+     model = Generator(params['mel_channel'])
+
+     if pretrained:
+         state_dict = torch.hub.load_state_dict_from_url(params['model_url'],
+                                                         progress=progress)
+         model.load_state_dict(state_dict['model_g'])
+
+     model.eval(inference=True)
+
+     return model
+
+
+ if __name__ == '__main__':
+     vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
+     mel = torch.randn(1, 80, 234)  # use your own mel-spectrogram here
+
+     print('Input mel-spectrogram shape: {}'.format(mel.shape))
+
+     if torch.cuda.is_available():
+         print('Moving data & model to GPU')
+         vocoder = vocoder.cuda()
+         mel = mel.cuda()
+
+     with torch.no_grad():
+         audio = vocoder.inference(mel)
+
+     print('Output audio shape: {}'.format(audio.shape))
melgan/inference.py ADDED
@@ -0,0 +1,49 @@
+ import os
+ import glob
+ import tqdm
+ import torch
+ import argparse
+ from scipy.io.wavfile import write
+
+ from model.generator import Generator
+ from utils.hparams import HParam, load_hparam_str
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def main(args):
+     checkpoint = torch.load(args.checkpoint_path)
+     if args.config is not None:
+         hp = HParam(args.config)
+     else:
+         hp = load_hparam_str(checkpoint['hp_str'])
+
+     model = Generator(hp.audio.n_mel_channels).cuda()
+     model.load_state_dict(checkpoint['model_g'])
+     model.eval(inference=False)  # keeps weight norm; Generator.eval(inference=True) would remove it
+
+     with torch.no_grad():
+         for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
+             mel = torch.load(melpath)
+             if len(mel.shape) == 2:
+                 mel = mel.unsqueeze(0)
+             mel = mel.cuda()
+
+             audio = model.inference(mel)
+             audio = audio.cpu().detach().numpy()
+
+             out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
+             write(out_path, hp.audio.sampling_rate, audio)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-c', '--config', type=str, default=None,
+                         help="yaml file for config. will use hp_str from checkpoint if not given.")
+     parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
+                         help="path of checkpoint pt file for evaluation")
+     parser.add_argument('-i', '--input_folder', type=str, required=True,
+                         help="directory of mel-spectrograms to invert into raw audio.")
+     args = parser.parse_args()
+
+     main(args)
melgan/model/discriminator.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Discriminator(nn.Module):
+     def __init__(self):
+         super(Discriminator, self).__init__()
+
+         self.discriminator = nn.ModuleList([
+             nn.Sequential(
+                 nn.ReflectionPad1d(7),
+                 nn.utils.weight_norm(nn.Conv1d(1, 16, kernel_size=15, stride=1)),
+                 nn.LeakyReLU(0.2, inplace=True),
+             ),
+             nn.Sequential(
+                 nn.utils.weight_norm(nn.Conv1d(16, 64, kernel_size=41, stride=4, padding=20, groups=4)),
+                 nn.LeakyReLU(0.2, inplace=True),
+             ),
+             nn.Sequential(
+                 nn.utils.weight_norm(nn.Conv1d(64, 256, kernel_size=41, stride=4, padding=20, groups=16)),
+                 nn.LeakyReLU(0.2, inplace=True),
+             ),
+             nn.Sequential(
+                 nn.utils.weight_norm(nn.Conv1d(256, 1024, kernel_size=41, stride=4, padding=20, groups=64)),
+                 nn.LeakyReLU(0.2, inplace=True),
+             ),
+             nn.Sequential(
+                 nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=41, stride=4, padding=20, groups=256)),
+                 nn.LeakyReLU(0.2, inplace=True),
+             ),
+             nn.Sequential(
+                 nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=5, stride=1, padding=2)),
+                 nn.LeakyReLU(0.2, inplace=True),
+             ),
+             nn.utils.weight_norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1)),
+         ])
+
+     def forward(self, x):
+         '''
+         returns: (list of 6 features, discriminator score)
+         we directly predict the score without a final sigmoid,
+         since we're using a least-squares GAN (https://arxiv.org/abs/1611.04076)
+         '''
+         features = list()
+         for module in self.discriminator:
+             x = module(x)
+             features.append(x)
+         return features[:-1], features[-1]
+
+
+ if __name__ == '__main__':
+     model = Discriminator()
+
+     x = torch.randn(3, 1, 22050)
+     print(x.shape)
+
+     features, score = model(x)
+     for feat in features:
+         print(feat.shape)
+     print(score.shape)
+
+     pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print(pytorch_total_params)
melgan/model/generator.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .res_stack import ResStack
+ # from res_stack import ResStack
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ class Generator(nn.Module):
+     def __init__(self, mel_channel):
+         super(Generator, self).__init__()
+         self.mel_channel = mel_channel
+
+         self.generator = nn.Sequential(
+             nn.ReflectionPad1d(3),
+             nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1)),
+
+             nn.LeakyReLU(0.2),
+             nn.utils.weight_norm(nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=4)),
+
+             ResStack(256),
+
+             nn.LeakyReLU(0.2),
+             nn.utils.weight_norm(nn.ConvTranspose1d(256, 128, kernel_size=16, stride=8, padding=4)),
+
+             ResStack(128),
+
+             nn.LeakyReLU(0.2),
+             nn.utils.weight_norm(nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)),
+
+             ResStack(64),
+
+             nn.LeakyReLU(0.2),
+             nn.utils.weight_norm(nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)),
+
+             ResStack(32),
+
+             nn.LeakyReLU(0.2),
+             nn.ReflectionPad1d(3),
+             nn.utils.weight_norm(nn.Conv1d(32, 1, kernel_size=7, stride=1)),
+             nn.Tanh(),
+         )
+
+     def forward(self, mel):
+         mel = (mel + 5.0) / 5.0  # roughly normalize spectrogram
+         return self.generator(mel)
+
+     def eval(self, inference=False):
+         super(Generator, self).eval()
+
+         # don't remove weight norm while validating within the training loop
+         if inference:
+             self.remove_weight_norm()
+
+     def remove_weight_norm(self):
+         for idx, layer in enumerate(self.generator):
+             if len(layer.state_dict()) != 0:
+                 try:
+                     nn.utils.remove_weight_norm(layer)
+                 except ValueError:
+                     # ResStack modules handle their own weight norm removal
+                     layer.remove_weight_norm()
+
+     def inference(self, mel):
+         hop_length = 256
+         # pad input mel with "silence" to cut the trailing artifact;
+         # -11.5129 == log(1e-5), the clamp floor of dynamic_range_compression
+         # see https://github.com/seungwonpark/melgan/issues/8
+         zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device)
+         mel = torch.cat((mel, zero), dim=2)
+
+         audio = self.forward(mel)
+         audio = audio.squeeze()  # collapse all dimensions except the time axis
+         audio = audio[:-(hop_length * 10)]
+         audio = MAX_WAV_VALUE * audio
+         audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
+         audio = audio.short()
+
+         return audio
+
+
+ '''
+ to run this as a script, change
+     from .res_stack import ResStack
+ into
+     from res_stack import ResStack
+ '''
+ if __name__ == '__main__':
+     model = Generator(80)
+
+     x = torch.randn(3, 80, 10)
+     print(x.shape)
+
+     y = model(x)
+     print(y.shape)
+     assert y.shape == torch.Size([3, 1, 2560])
+
+     pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print(pytorch_total_params)
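
The four transposed convolutions upsample by 8 × 8 × 2 × 2 = 256 samples per mel frame, which is exactly why `hp.audio.hop_length` is pinned to 256 in the config. A quick check of that ratio, assuming the repository root is on `PYTHONPATH`:

```python
import torch
from model.generator import Generator

model = Generator(80)
mel = torch.randn(1, 80, 100)  # 100 mel frames
with torch.no_grad():
    audio = model(mel)
print(audio.shape)  # torch.Size([1, 1, 25600]) == 100 frames * 256
```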
melgan/model/identity.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Identity(nn.Module):
+     def __init__(self):
+         super(Identity, self).__init__()
+
+     def forward(self, x):
+         return x
melgan/model/multiscale.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .discriminator import Discriminator
+ from .identity import Identity
+
+
+ class MultiScaleDiscriminator(nn.Module):
+     def __init__(self):
+         super(MultiScaleDiscriminator, self).__init__()
+
+         self.discriminators = nn.ModuleList(
+             [Discriminator() for _ in range(3)]
+         )
+
+         self.pooling = nn.ModuleList(
+             [Identity()] +
+             [nn.AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False) for _ in range(1, 3)]
+         )
+
+     def forward(self, x):
+         ret = list()
+
+         for pool, disc in zip(self.pooling, self.discriminators):
+             x = pool(x)
+             ret.append(disc(x))
+
+         return ret  # [(feat, score), (feat, score), (feat, score)]
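
A small sketch of what `forward` returns: three (features, score) tuples, one per scale, with the input average-pooled by 2× between consecutive scales (same `PYTHONPATH` assumption as above):

```python
import torch
from model.multiscale import MultiScaleDiscriminator

disc = MultiScaleDiscriminator()
x = torch.randn(3, 1, 22050)  # one second of audio at 22050 Hz
for i, (features, score) in enumerate(disc(x)):
    print(i, len(features), score.shape)  # 6 intermediate feature maps per scale
```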
melgan/model/res_stack.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+
+ class ResStack(nn.Module):
+     def __init__(self, channel):
+         super(ResStack, self).__init__()
+
+         self.blocks = nn.ModuleList([
+             nn.Sequential(
+                 nn.LeakyReLU(0.2),
+                 nn.ReflectionPad1d(3**i),
+                 nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=3**i)),
+                 nn.LeakyReLU(0.2),
+                 nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
+             )
+             for i in range(3)
+         ])
+
+         self.shortcuts = nn.ModuleList([
+             nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
+             for i in range(3)
+         ])
+
+     def forward(self, x):
+         for block, shortcut in zip(self.blocks, self.shortcuts):
+             x = shortcut(x) + block(x)
+         return x
+
+     def remove_weight_norm(self):
+         for block, shortcut in zip(self.blocks, self.shortcuts):
+             nn.utils.remove_weight_norm(block[2])
+             nn.utils.remove_weight_norm(block[4])
+             nn.utils.remove_weight_norm(shortcut)
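
Each block pairs a dilation of 3^i with a reflection padding of 3^i around a 3-tap conv, so the stack is length-preserving; a quick shape check (same import assumption as above):

```python
import torch
from model.res_stack import ResStack

stack = ResStack(32)
x = torch.randn(1, 32, 640)
print(stack(x).shape)  # torch.Size([1, 32, 640]) -- time length unchanged
```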
melgan/preprocess.py ADDED
@@ -0,0 +1,50 @@
+ import os
+ import glob
+ import tqdm
+ import torch
+ import argparse
+ import numpy as np
+
+ from utils.stft import TacotronSTFT
+ from utils.hparams import HParam
+ from utils.utils import read_wav_np
+
+
+ def main(hp, args):
+     stft = TacotronSTFT(filter_length=hp.audio.filter_length,
+                         hop_length=hp.audio.hop_length,
+                         win_length=hp.audio.win_length,
+                         n_mel_channels=hp.audio.n_mel_channels,
+                         sampling_rate=hp.audio.sampling_rate,
+                         mel_fmin=hp.audio.mel_fmin,
+                         mel_fmax=hp.audio.mel_fmax)
+
+     wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'), recursive=True)
+
+     for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
+         sr, wav = read_wav_np(wavpath)
+         assert sr == hp.audio.sampling_rate, \
+             "sample rate mismatch. expected %d, got %d at %s" % \
+             (hp.audio.sampling_rate, sr, wavpath)
+
+         if len(wav) < hp.audio.segment_length + hp.audio.pad_short:
+             wav = np.pad(wav, (0, hp.audio.segment_length + hp.audio.pad_short - len(wav)),
+                          mode='constant', constant_values=0.0)
+
+         wav = torch.from_numpy(wav).unsqueeze(0)
+         mel = stft.mel_spectrogram(wav)
+
+         # save the mel next to its wav, as expected by datasets/dataloader.py
+         melpath = wavpath.replace('.wav', '.mel')
+         torch.save(mel, melpath)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-c', '--config', type=str, required=True,
+                         help="yaml file for config.")
+     parser.add_argument('-d', '--data_path', type=str, required=True,
+                         help="root directory of wav files")
+     args = parser.parse_args()
+     hp = HParam(args.config)
+
+     main(hp, args)
melgan/requirements.txt ADDED
@@ -0,0 +1,9 @@
+ librosa
+ matplotlib
+ numpy
+ scipy
+ tensorboardX
+ torch
+ tqdm
+ pillow
+ pyyaml
melgan/trainer.py ADDED
@@ -0,0 +1,52 @@
+ import os
+ import time
+ import logging
+ import argparse
+
+ from utils.train import train
+ from utils.hparams import HParam
+ from utils.writer import MyWriter
+ from datasets.dataloader import create_dataloader
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-c', '--config', type=str, required=True,
+                         help="yaml file for configuration")
+     parser.add_argument('-p', '--checkpoint_path', type=str, default=None,
+                         help="path of checkpoint pt file to resume training")
+     parser.add_argument('-n', '--name', type=str, required=True,
+                         help="name of the model for logging, saving checkpoint")
+     args = parser.parse_args()
+
+     hp = HParam(args.config)
+     with open(args.config, 'r') as f:
+         hp_str = ''.join(f.readlines())
+
+     pt_dir = os.path.join(hp.log.chkpt_dir, args.name)
+     log_dir = os.path.join(hp.log.log_dir, args.name)
+     os.makedirs(pt_dir, exist_ok=True)
+     os.makedirs(log_dir, exist_ok=True)
+
+     logging.basicConfig(
+         level=logging.INFO,
+         format='%(asctime)s - %(levelname)s - %(message)s',
+         handlers=[
+             logging.FileHandler(os.path.join(log_dir,
+                 '%s-%d.log' % (args.name, time.time()))),
+             logging.StreamHandler()
+         ]
+     )
+     logger = logging.getLogger()
+
+     writer = MyWriter(hp, log_dir)
+
+     assert hp.audio.hop_length == 256, \
+         'hp.audio.hop_length must be equal to 256, got %d' % hp.audio.hop_length
+     assert hp.data.train != '' and hp.data.validation != '', \
+         'hp.data.train and hp.data.validation can\'t be empty: please fix %s' % args.config
+
+     trainloader = create_dataloader(hp, args, True)
+     valloader = create_dataloader(hp, args, False)
+
+     train(args, pt_dir, args.checkpoint_path, trainloader, valloader, writer, logger, hp, hp_str)
melgan/utils/audio_processing.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ import numpy as np
+ from scipy.signal import get_window
+ import librosa.util as librosa_util
+
+
+ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
+                      n_fft=800, dtype=np.float32, norm=None):
+     """
+     # from librosa 0.6
+     Compute the sum-square envelope of a window function at a given hop length.
+
+     This is used to estimate modulation effects induced by windowing
+     observations in short-time fourier transforms.
+
+     Parameters
+     ----------
+     window : string, tuple, number, callable, or list-like
+         Window specification, as in `get_window`
+
+     n_frames : int > 0
+         The number of analysis frames
+
+     hop_length : int > 0
+         The number of samples to advance between frames
+
+     win_length : [optional]
+         The length of the window function. By default, this matches `n_fft`.
+
+     n_fft : int > 0
+         The length of each analysis frame.
+
+     dtype : np.dtype
+         The data type of the output
+
+     Returns
+     -------
+     wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+         The sum-squared envelope of the window function
+     """
+     if win_length is None:
+         win_length = n_fft
+
+     n = n_fft + hop_length * (n_frames - 1)
+     x = np.zeros(n, dtype=dtype)
+
+     # Compute the squared window at the desired length
+     win_sq = get_window(window, win_length, fftbins=True)
+     win_sq = librosa_util.normalize(win_sq, norm=norm)**2
+     win_sq = librosa_util.pad_center(win_sq, n_fft)
+
+     # Fill the envelope
+     for i in range(n_frames):
+         sample = i * hop_length
+         x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
+     return x
+
+
+ def griffin_lim(magnitudes, stft_fn, n_iters=30):
+     """
+     PARAMS
+     ------
+     magnitudes: spectrogram magnitudes
+     stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+     """
+
+     angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+     angles = angles.astype(np.float32)
+     angles = torch.autograd.Variable(torch.from_numpy(angles))
+     signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+     for i in range(n_iters):
+         _, angles = stft_fn.transform(signal)
+         signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+     return signal
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     """
+     PARAMS
+     ------
+     C: compression factor
+     """
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     """
+     PARAMS
+     ------
+     C: compression factor used to compress
+     """
+     return torch.exp(x) / C
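
The `clip_val=1e-5` floor in `dynamic_range_compression` is also where the magic constant in `Generator.inference` comes from: log(1e-5) ≈ -11.5129 is the log-mel value of pure silence, which is what the generator pads with to cut the trailing artifact.

```python
import math
print(math.log(1e-5))  # -11.512925464970229
```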
melgan/utils/hparams.py ADDED
@@ -0,0 +1,67 @@
+ # modified from https://github.com/HarryVolek/PyTorch_Speaker_Verification
+
+ import os
+ import yaml
+
+
+ def load_hparam_str(hp_str):
+     path = 'temp-restore.yaml'
+     with open(path, 'w') as f:
+         f.write(hp_str)
+     ret = HParam(path)
+     os.remove(path)
+     return ret
+
+
+ def load_hparam(filename):
+     stream = open(filename, 'r')
+     docs = yaml.load_all(stream, Loader=yaml.Loader)
+     hparam_dict = dict()
+     for doc in docs:
+         for k, v in doc.items():
+             hparam_dict[k] = v
+     return hparam_dict
+
+
+ def merge_dict(user, default):
+     if isinstance(user, dict) and isinstance(default, dict):
+         for k, v in default.items():
+             if k not in user:
+                 user[k] = v
+             else:
+                 user[k] = merge_dict(user[k], v)
+     return user
+
+
+ class Dotdict(dict):
+     """
+     a dictionary that supports dot notation
+     as well as dictionary access notation
+     usage: d = Dotdict() or d = Dotdict({'val1': 'first'})
+     set attributes: d.val2 = 'second' or d['val2'] = 'second'
+     get attributes: d.val2 or d['val2']
+     """
+     __getattr__ = dict.__getitem__
+     __setattr__ = dict.__setitem__
+     __delattr__ = dict.__delitem__
+
+     def __init__(self, dct=None):
+         dct = dict() if not dct else dct
+         for key, value in dct.items():
+             if hasattr(value, 'keys'):
+                 value = Dotdict(value)
+             self[key] = value
+
+
+ class HParam(Dotdict):
+
+     def __init__(self, file):
+         super(Dotdict, self).__init__()
+         hp_dict = load_hparam(file)
+         hp_dotdict = Dotdict(hp_dict)
+         for k, v in hp_dotdict.items():
+             setattr(self, k, v)
+
+     __getattr__ = Dotdict.__getitem__
+     __setattr__ = Dotdict.__setitem__
+     __delattr__ = Dotdict.__delitem__
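
A usage sketch: `config/default.yaml` contains several YAML documents separated by `---`; `load_hparam` folds them into one dict, and `HParam`/`Dotdict` expose it with attribute access.

```python
from utils.hparams import HParam

hp = HParam('config/default.yaml')
print(hp.audio.hop_length)  # 256
print(hp.train.adam.lr)     # 0.0001
```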
melgan/utils/plotting.py ADDED
@@ -0,0 +1,29 @@
+ import matplotlib
+ matplotlib.use("Agg")
+ import matplotlib.pylab as plt
+ import numpy as np
+
+
+ def save_figure_to_numpy(fig):
+     # save the rendered figure to a numpy array of shape (C, H, W)
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     data = np.transpose(data, (2, 0, 1))
+     return data
+
+
+ def plot_waveform_to_numpy(waveform):
+     fig, ax = plt.subplots(figsize=(12, 3))
+     ax.plot(range(len(waveform)), waveform,
+             linewidth=0.1, alpha=0.7, color='blue')
+
+     plt.xlabel("Samples")
+     plt.ylabel("Amplitude")
+     plt.ylim(-1, 1)
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = save_figure_to_numpy(fig)
+     plt.close()
+     return data
melgan/utils/stft.py ADDED
@@ -0,0 +1,184 @@
+ """
+ BSD 3-Clause License
+
+ Copyright (c) 2017, Prem Seetharaman
+ All rights reserved.
+
+ * Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice,
+   this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice, this
+   list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from this
+   software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ """
+
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+ from torch.autograd import Variable
+ from scipy.signal import get_window
+ from librosa.util import pad_center, tiny
+ from .audio_processing import window_sumsquare, dynamic_range_compression, dynamic_range_decompression
+ from librosa.filters import mel as librosa_mel_fn
+
+
+ class STFT(torch.nn.Module):
+     """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+     def __init__(self, filter_length=800, hop_length=200, win_length=800,
+                  window='hann'):
+         super(STFT, self).__init__()
+         self.filter_length = filter_length
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.window = window
+         self.forward_transform = None
+         scale = self.filter_length / self.hop_length
+         fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+         cutoff = int((self.filter_length / 2 + 1))
+         fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
+                                    np.imag(fourier_basis[:cutoff, :])])
+
+         forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+         inverse_basis = torch.FloatTensor(
+             np.linalg.pinv(scale * fourier_basis).T[:, None, :])
+
+         if window is not None:
+             assert(filter_length >= win_length)
+             # get window and zero center pad it to filter_length
+             fft_window = get_window(window, win_length, fftbins=True)
+             fft_window = pad_center(fft_window, filter_length)
+             fft_window = torch.from_numpy(fft_window).float()
+
+             # window the bases
+             forward_basis *= fft_window
+             inverse_basis *= fft_window
+
+         self.register_buffer('forward_basis', forward_basis.float())
+         self.register_buffer('inverse_basis', inverse_basis.float())
+
+     def transform(self, input_data):
+         num_batches = input_data.size(0)
+         num_samples = input_data.size(1)
+
+         self.num_samples = num_samples
+
+         # similar to librosa, reflect-pad the input
+         input_data = input_data.view(num_batches, 1, num_samples)
+         input_data = F.pad(
+             input_data.unsqueeze(1),
+             (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+             mode='reflect')
+         input_data = input_data.squeeze(1)
+
+         # https://github.com/NVIDIA/tacotron2/issues/125
+         forward_transform = F.conv1d(
+             input_data.cuda(),
+             Variable(self.forward_basis, requires_grad=False).cuda(),
+             stride=self.hop_length,
+             padding=0).cpu()
+
+         cutoff = int((self.filter_length / 2) + 1)
+         real_part = forward_transform[:, :cutoff, :]
+         imag_part = forward_transform[:, cutoff:, :]
+
+         magnitude = torch.sqrt(real_part**2 + imag_part**2)
+         phase = torch.autograd.Variable(
+             torch.atan2(imag_part.data, real_part.data))
+
+         return magnitude, phase
+
+     def inverse(self, magnitude, phase):
+         recombine_magnitude_phase = torch.cat(
+             [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
+
+         inverse_transform = F.conv_transpose1d(
+             recombine_magnitude_phase,
+             Variable(self.inverse_basis, requires_grad=False),
+             stride=self.hop_length,
+             padding=0)
+
+         if self.window is not None:
+             window_sum = window_sumsquare(
+                 self.window, magnitude.size(-1), hop_length=self.hop_length,
+                 win_length=self.win_length, n_fft=self.filter_length,
+                 dtype=np.float32)
+             # remove modulation effects
+             approx_nonzero_indices = torch.from_numpy(
+                 np.where(window_sum > tiny(window_sum))[0])
+             window_sum = torch.autograd.Variable(
+                 torch.from_numpy(window_sum), requires_grad=False)
+             window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
+             inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+
+             # scale by hop ratio
+             inverse_transform *= float(self.filter_length) / self.hop_length
+
+         inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
+         inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
+
+         return inverse_transform
+
+     def forward(self, input_data):
+         self.magnitude, self.phase = self.transform(input_data)
+         reconstruction = self.inverse(self.magnitude, self.phase)
+         return reconstruction
+
+
+ class TacotronSTFT(torch.nn.Module):
+     def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
+                  n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
+                  mel_fmax=None):
+         super(TacotronSTFT, self).__init__()
+         self.n_mel_channels = n_mel_channels
+         self.sampling_rate = sampling_rate
+         self.stft_fn = STFT(filter_length, hop_length, win_length)
+         mel_basis = librosa_mel_fn(
+             sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
+         mel_basis = torch.from_numpy(mel_basis).float()
+         self.register_buffer('mel_basis', mel_basis)
+
+     def spectral_normalize(self, magnitudes):
+         output = dynamic_range_compression(magnitudes)
+         return output
+
+     def spectral_de_normalize(self, magnitudes):
+         output = dynamic_range_decompression(magnitudes)
+         return output
+
+     def mel_spectrogram(self, y):
+         """Computes mel-spectrograms from a batch of waves
+         PARAMS
+         ------
+         y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+         RETURNS
+         -------
+         mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+         """
+         assert(torch.min(y.data) >= -1)
+         assert(torch.max(y.data) <= 1)
+
+         magnitudes, phases = self.stft_fn.transform(y)
+         magnitudes = magnitudes.data
+         mel_output = torch.matmul(self.mel_basis, magnitudes)
+         mel_output = self.spectral_normalize(mel_output)
+         return mel_output
melgan/utils/train.py ADDED
@@ -0,0 +1,131 @@
+ import os
+ import math
+ import tqdm
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import itertools
+ import traceback
+
+ from model.generator import Generator
+ from model.multiscale import MultiScaleDiscriminator
+ from .utils import get_commit_hash
+ from .validation import validate
+
+
+ def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
+     model_g = Generator(hp.audio.n_mel_channels).cuda()
+     model_d = MultiScaleDiscriminator().cuda()
+
+     optim_g = torch.optim.Adam(model_g.parameters(),
+                                lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
+     optim_d = torch.optim.Adam(model_d.parameters(),
+                                lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
+
+     githash = get_commit_hash()
+
+     init_epoch = -1
+     step = 0
+
+     if chkpt_path is not None:
+         logger.info("Resuming from checkpoint: %s" % chkpt_path)
+         checkpoint = torch.load(chkpt_path)
+         model_g.load_state_dict(checkpoint['model_g'])
+         model_d.load_state_dict(checkpoint['model_d'])
+         optim_g.load_state_dict(checkpoint['optim_g'])
+         optim_d.load_state_dict(checkpoint['optim_d'])
+         step = checkpoint['step']
+         init_epoch = checkpoint['epoch']
+
+         if hp_str != checkpoint['hp_str']:
+             logger.warning("New hparams are different from checkpoint. Will use new.")
+
+         if githash != checkpoint['githash']:
+             logger.warning("Code might be different: git hash is different.")
+             logger.warning("%s -> %s" % (checkpoint['githash'], githash))
+
+     else:
+         logger.info("Starting new training run.")
+
+     # this accelerates training when the size of the minibatch is always consistent.
+     # if not consistent, it'll horribly slow down.
+     torch.backends.cudnn.benchmark = True
+
+     try:
+         model_g.train()
+         model_d.train()
+         for epoch in itertools.count(init_epoch + 1):
+             if epoch % hp.log.validation_interval == 0:
+                 with torch.no_grad():
+                     validate(hp, args, model_g, model_d, valloader, writer, step)
+
+             trainloader.dataset.shuffle_mapping()
+             loader = tqdm.tqdm(trainloader, desc='Loading train data')
+             for (melG, audioG), (melD, audioD) in loader:
+                 melG = melG.cuda()
+                 audioG = audioG.cuda()
+                 melD = melD.cuda()
+                 audioD = audioD.cuda()
+
+                 # generator: LSGAN loss + feature-matching loss
+                 optim_g.zero_grad()
+                 fake_audio = model_g(melG)[:, :, :hp.audio.segment_length]
+                 disc_fake = model_d(fake_audio)
+                 disc_real = model_d(audioG)
+                 loss_g = 0.0
+                 for (feats_fake, score_fake), (feats_real, _) in zip(disc_fake, disc_real):
+                     loss_g += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
+                     for feat_f, feat_r in zip(feats_fake, feats_real):
+                         loss_g += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
+
+                 loss_g.backward()
+                 optim_g.step()
+
+                 # discriminator: LSGAN loss on an independently sampled batch
+                 fake_audio = model_g(melD)[:, :, :hp.audio.segment_length]
+                 fake_audio = fake_audio.detach()
+                 loss_d_sum = 0.0
+                 for _ in range(hp.train.rep_discriminator):
+                     optim_d.zero_grad()
+                     disc_fake = model_d(fake_audio)
+                     disc_real = model_d(audioD)
+                     loss_d = 0.0
+                     for (_, score_fake), (_, score_real) in zip(disc_fake, disc_real):
+                         loss_d += torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
+                         loss_d += torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
+
+                     loss_d.backward()
+                     optim_d.step()
+                     loss_d_sum += loss_d
+
+                 step += 1
+                 # logging
+                 loss_g = loss_g.item()
+                 loss_d_avg = loss_d_sum / hp.train.rep_discriminator
+                 loss_d_avg = loss_d_avg.item()
+                 if any([loss_g > 1e8, math.isnan(loss_g), loss_d_avg > 1e8, math.isnan(loss_d_avg)]):
+                     logger.error("loss_g %.01f loss_d_avg %.01f at step %d!" % (loss_g, loss_d_avg, step))
+                     raise Exception("Loss exploded")
+
+                 if step % hp.log.summary_interval == 0:
+                     writer.log_training(loss_g, loss_d_avg, step)
+                     loader.set_description("g %.04f d %.04f | step %d" % (loss_g, loss_d_avg, step))
+
+             if epoch % hp.log.save_interval == 0:
+                 save_path = os.path.join(pt_dir, '%s_%s_%04d.pt'
+                                          % (args.name, githash, epoch))
+                 torch.save({
+                     'model_g': model_g.state_dict(),
+                     'model_d': model_d.state_dict(),
+                     'optim_g': optim_g.state_dict(),
+                     'optim_d': optim_d.state_dict(),
+                     'step': step,
+                     'epoch': epoch,
+                     'hp_str': hp_str,
+                     'githash': githash,
+                 }, save_path)
+                 logger.info("Saved checkpoint to: %s" % save_path)
+
+     except Exception as e:
+         logger.info("Exiting due to exception: %s" % e)
+         traceback.print_exc()
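
In equations, the loop above implements the least-squares GAN objective with feature matching, summed over the three discriminator scales (up to the exact sum-vs-mean reduction used in the code), with x real audio, x̂ = G(mel), λ_fm = `hp.model.feat_match` = 10, and D_k^(i) the six intermediate feature maps of scale k:

```latex
\mathcal{L}_G
  = \sum_{k=1}^{3} \mathbb{E}\!\left[(D_k(\hat{x}) - 1)^2\right]
  + \lambda_{\mathrm{fm}} \sum_{k=1}^{3} \sum_{i=1}^{6}
      \mathbb{E}\left\lVert D_k^{(i)}(\hat{x}) - D_k^{(i)}(x) \right\rVert_1,
\qquad
\mathcal{L}_D
  = \sum_{k=1}^{3} \mathbb{E}\!\left[(D_k(x) - 1)^2 + D_k(\hat{x})^2\right]
```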
melgan/utils/utils.py ADDED
@@ -0,0 +1,26 @@
+ import random
+ import subprocess
+ import numpy as np
+ from scipy.io.wavfile import read
+
+
+ def get_commit_hash():
+     message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
+     return message.strip().decode('utf-8')
+
+
+ def read_wav_np(path):
+     sr, wav = read(path)
+
+     if len(wav.shape) == 2:
+         wav = wav[:, 0]
+
+     if wav.dtype == np.int16:
+         wav = wav / 32768.0
+     elif wav.dtype == np.int32:
+         wav = wav / 2147483648.0
+     elif wav.dtype == np.uint8:
+         wav = (wav - 128) / 128.0
+
+     wav = wav.astype(np.float32)
+
+     return sr, wav
melgan/utils/validation.py ADDED
@@ -0,0 +1,41 @@
+ import tqdm
+ import torch
+
+
+ def validate(hp, args, generator, discriminator, valloader, writer, step):
+     generator.eval()
+     discriminator.eval()
+     torch.backends.cudnn.benchmark = False
+
+     loader = tqdm.tqdm(valloader, desc='Validation loop')
+     loss_g_sum = 0.0
+     loss_d_sum = 0.0
+     for mel, audio in loader:
+         mel = mel.cuda()
+         audio = audio.cuda()
+
+         # generator
+         fake_audio = generator(mel)
+         disc_fake = discriminator(fake_audio[:, :, :audio.size(2)])
+         disc_real = discriminator(audio)
+         loss_g = 0.0
+         loss_d = 0.0
+         for (feats_fake, score_fake), (feats_real, score_real) in zip(disc_fake, disc_real):
+             loss_g += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
+             for feat_f, feat_r in zip(feats_fake, feats_real):
+                 loss_g += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
+             loss_d += torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
+             loss_d += torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
+
+         loss_g_sum += loss_g.item()
+         loss_d_sum += loss_d.item()
+
+     loss_g_avg = loss_g_sum / len(valloader.dataset)
+     loss_d_avg = loss_d_sum / len(valloader.dataset)
+
+     audio = audio[0][0].cpu().detach().numpy()
+     fake_audio = fake_audio[0][0].cpu().detach().numpy()
+
+     writer.log_validation(loss_g_avg, loss_d_avg, generator, discriminator, audio, fake_audio, step)
+
+     torch.backends.cudnn.benchmark = True
melgan/utils/writer.py ADDED
@@ -0,0 +1,33 @@
+ from tensorboardX import SummaryWriter
+
+ from .plotting import plot_waveform_to_numpy
+
+
+ class MyWriter(SummaryWriter):
+     def __init__(self, hp, logdir):
+         super(MyWriter, self).__init__(logdir)
+         self.sample_rate = hp.audio.sampling_rate
+         self.is_first = True
+
+     def log_training(self, g_loss, d_loss, step):
+         self.add_scalar('train.g_loss', g_loss, step)
+         self.add_scalar('train.d_loss', d_loss, step)
+
+     def log_validation(self, g_loss, d_loss, generator, discriminator, target, prediction, step):
+         self.add_scalar('validation.g_loss', g_loss, step)
+         self.add_scalar('validation.d_loss', d_loss, step)
+
+         self.add_audio('raw_audio_predicted', prediction, step, self.sample_rate)
+         self.add_image('waveform_predicted', plot_waveform_to_numpy(prediction), step)
+
+         self.log_histogram(generator, step)
+         self.log_histogram(discriminator, step)
+
+         if self.is_first:
+             self.add_audio('raw_audio_target', target, step, self.sample_rate)
+             self.add_image('waveform_target', plot_waveform_to_numpy(target), step)
+             self.is_first = False
+
+     def log_histogram(self, model, step):
+         for tag, value in model.named_parameters():
+             self.add_histogram(tag.replace('.', '/'), value.cpu().detach().numpy(), step)