Upload MelGAN model
repository cloned from:
https://github.com/seungwonpark/melgan/tree/master
- melgan/.gitignore +119 -0
- melgan/LICENSE +29 -0
- melgan/README.md +83 -0
- melgan/assets/gd.png +0 -0
- melgan/assets/lj-tensorboard-v0.3-alpha.png +0 -0
- melgan/assets/lj-tensorboard.png +0 -0
- melgan/config/default.yaml +34 -0
- melgan/datasets/dataloader.py +67 -0
- melgan/hubconf.py +41 -0
- melgan/inference.py +49 -0
- melgan/model/discriminator.py +64 -0
- melgan/model/generator.py +99 -0
- melgan/model/identity.py +12 -0
- melgan/model/multiscale.py +29 -0
- melgan/model/res_stack.py +36 -0
- melgan/preprocess.py +50 -0
- melgan/requirements.txt +9 -0
- melgan/trainer.py +52 -0
- melgan/utils/audio_processing.py +93 -0
- melgan/utils/hparams.py +67 -0
- melgan/utils/plotting.py +29 -0
- melgan/utils/stft.py +184 -0
- melgan/utils/train.py +131 -0
- melgan/utils/utils.py +26 -0
- melgan/utils/validation.py +41 -0
- melgan/utils/writer.py +33 -0
melgan/.gitignore
ADDED
@@ -0,0 +1,119 @@
+# IDE configuration
+.idea/
+
+# configuration
+config/*
+!config/default.yaml
+temp-restore.yaml
+
+# logs, checkpoints
+chkpt/
+logs/
+
+# just a temporary folder
+temp/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
melgan/LICENSE
ADDED
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2019, Seungwon Park 박승원
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
melgan/README.md
ADDED
@@ -0,0 +1,83 @@
+# MelGAN
+Unofficial PyTorch implementation of [MelGAN vocoder](https://arxiv.org/abs/1910.06711)
+
+## Key Features
+
+- MelGAN is lighter, faster, and better at generalizing to unseen speakers than [WaveGlow](https://github.com/NVIDIA/waveglow).
+- This repository uses the same mel-spectrogram function as [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2), so it can be used directly to convert the output of NVIDIA's Tacotron 2 into raw audio.
+- Pretrained model on LJSpeech-1.1 via [PyTorch Hub](https://pytorch.org/hub).
+
+![](./assets/gd.png)
+
+## Prerequisites
+
+Tested on Python 3.6
+```bash
+pip install -r requirements.txt
+```
+
+## Prepare Dataset
+
+- Download a dataset for training. This can be any set of wav files with a 22050 Hz sample rate (e.g. LJSpeech, which was used in the paper).
+- Preprocess: `python preprocess.py -c config/default.yaml -d [data's root path]`
+- Edit the configuration `yaml` file
+
+## Train & Tensorboard
+
+- `python trainer.py -c [config yaml file] -n [name of the run]`
+  - `cp config/default.yaml config/config.yaml` and then edit `config.yaml`
+  - Write the root paths of the train/validation data on the 2nd/3rd lines.
+  - Each path should contain pairs of `*.wav` files with corresponding (preprocessed) `*.mel` files.
+  - The data loader parses the list of files within the path recursively.
+- `tensorboard --logdir logs/`
+
+## Pretrained model
+
+Try with Google Colab: TODO
+
+```python
+import torch
+vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
+vocoder.eval()
+mel = torch.randn(1, 80, 234)  # use your own mel-spectrogram here
+
+if torch.cuda.is_available():
+    vocoder = vocoder.cuda()
+    mel = mel.cuda()
+
+with torch.no_grad():
+    audio = vocoder.inference(mel)
+```
+
+## Inference
+
+- `python inference.py -p [checkpoint path] -i [input mel path]`
+
+## Results
+
+See audio samples at: http://swpark.me/melgan/.
+The model was trained on a V100 GPU for 14 days using LJSpeech-1.1.
+
+![](./assets/lj-tensorboard-v0.3-alpha.png)
+
+
+## Implementation Authors
+
+- [Seungwon Park](http://swpark.me) @ MINDsLab Inc. ([email protected], [email protected])
+- Myunchul Joe @ MINDsLab Inc.
+- [Rishikesh](https://github.com/rishikksh20) @ DeepSync Technologies Pvt Ltd.
+
+## License
+
+BSD 3-Clause License.
+
+- [utils/stft.py](./utils/stft.py) by Prem Seetharaman (BSD 3-Clause License)
+- [datasets/mel2samp.py](./datasets/mel2samp.py) from https://github.com/NVIDIA/waveglow (BSD 3-Clause License)
+- [utils/hparams.py](./utils/hparams.py) from https://github.com/HarryVolek/PyTorch_Speaker_Verification (no license specified)
+
+## Useful resources
+
+- [How to Train a GAN? Tips and tricks to make GANs work](https://github.com/soumith/ganhacks) by Soumith Chintala
+- [Official MelGAN implementation by original authors](https://github.com/descriptinc/melgan-neurips)
+- [Reproduction of MelGAN - NeurIPS 2019 Reproducibility Challenge (Ablation Track)](https://openreview.net/pdf?id=9jTbNbBNw0) by Yifei Zhao, Yichao Yang, and Yang Gao
+  - "replacing the average pooling layer with max pooling layer and replacing reflection padding with replication padding improves the performance significantly, while combining them produces worse results"
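
As a sanity check tying the torch.hub snippet above to the rest of this upload: `preprocess.py` saves mel-spectrograms with `torch.save`, and `Generator.inference` returns int16 audio, so a saved `.mel` file can be round-tripped to a wav roughly as follows. This is a minimal sketch; the `.mel` filename is a hypothetical example.

```python
import torch
from scipy.io.wavfile import write

vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')  # downloads pretrained weights
vocoder.eval()

mel = torch.load('LJ001-0001.mel')   # produced by preprocess.py; shape (1, 80, T)
if mel.dim() == 2:
    mel = mel.unsqueeze(0)

with torch.no_grad():
    audio = vocoder.inference(mel)   # int16 tensor, one hop_length (256) of samples per frame

write('LJ001-0001_reconstructed.wav', 22050, audio.cpu().numpy())
```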
melgan/assets/gd.png
ADDED
melgan/assets/lj-tensorboard-v0.3-alpha.png
ADDED
melgan/assets/lj-tensorboard.png
ADDED
melgan/config/default.yaml
ADDED
@@ -0,0 +1,34 @@
+data: # root path of train/validation data (either relative/absolute path is ok)
+  train: ''
+  validation: ''
+---
+train:
+  rep_discriminator: 1
+  num_workers: 32
+  batch_size: 16
+  optimizer: 'adam'
+  adam:
+    lr: 0.0001
+    beta1: 0.5
+    beta2: 0.9
+---
+audio:
+  n_mel_channels: 80
+  segment_length: 16000
+  pad_short: 2000
+  filter_length: 1024
+  hop_length: 256 # WARNING: this can't be changed.
+  win_length: 1024
+  sampling_rate: 22050
+  mel_fmin: 0.0
+  mel_fmax: 8000.0
+---
+model:
+  feat_match: 10.0
+---
+log:
+  summary_interval: 1
+  validation_interval: 5
+  save_interval: 25
+  chkpt_dir: 'chkpt'
+  log_dir: 'logs'
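
Note that the `---` separators make this a multi-document YAML file: `utils/hparams.py` (further below) loads every document with `yaml.load_all` and merges them into one dot-accessible object. A minimal usage sketch, assuming it is run from the repository root:

```python
from utils.hparams import HParam

hp = HParam('config/default.yaml')
print(hp.train.batch_size)   # 16
print(hp.audio.hop_length)   # 256
print(hp.data.train)         # '' until you fill it in
```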
melgan/datasets/dataloader.py
ADDED
@@ -0,0 +1,67 @@
+import os
+import glob
+import torch
+import random
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+
+from utils.utils import read_wav_np
+
+
+def create_dataloader(hp, args, train):
+    dataset = MelFromDisk(hp, args, train)
+
+    if train:
+        return DataLoader(dataset=dataset, batch_size=hp.train.batch_size, shuffle=True,
+            num_workers=hp.train.num_workers, pin_memory=True, drop_last=True)
+    else:
+        return DataLoader(dataset=dataset, batch_size=1, shuffle=False,
+            num_workers=hp.train.num_workers, pin_memory=True, drop_last=False)
+
+
+class MelFromDisk(Dataset):
+    def __init__(self, hp, args, train):
+        self.hp = hp
+        self.args = args
+        self.train = train
+        self.path = hp.data.train if train else hp.data.validation
+        self.wav_list = glob.glob(os.path.join(self.path, '**', '*.wav'), recursive=True)
+        self.mel_segment_length = hp.audio.segment_length // hp.audio.hop_length + 2
+        self.mapping = [i for i in range(len(self.wav_list))]
+
+    def __len__(self):
+        return len(self.wav_list)
+
+    def __getitem__(self, idx):
+        if self.train:
+            idx1 = idx
+            idx2 = self.mapping[idx1]
+            return self.my_getitem(idx1), self.my_getitem(idx2)
+        else:
+            return self.my_getitem(idx)
+
+    def shuffle_mapping(self):
+        random.shuffle(self.mapping)
+
+    def my_getitem(self, idx):
+        wavpath = self.wav_list[idx]
+        melpath = wavpath.replace('.wav', '.mel')
+        sr, audio = read_wav_np(wavpath)
+        if len(audio) < self.hp.audio.segment_length + self.hp.audio.pad_short:
+            audio = np.pad(audio, (0, self.hp.audio.segment_length + self.hp.audio.pad_short - len(audio)), \
+                    mode='constant', constant_values=0.0)
+
+        audio = torch.from_numpy(audio).unsqueeze(0)
+        mel = torch.load(melpath).squeeze(0)
+
+        if self.train:
+            max_mel_start = mel.size(1) - self.mel_segment_length
+            mel_start = random.randint(0, max_mel_start)
+            mel_end = mel_start + self.mel_segment_length
+            mel = mel[:, mel_start:mel_end]
+
+            audio_start = mel_start * self.hp.audio.hop_length
+            audio = audio[:, audio_start:audio_start+self.hp.audio.segment_length]
+
+        audio = audio + (1/32768) * torch.randn_like(audio)
+        return mel, audio
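
The `+ 2` in `mel_segment_length` gives the random mel crop a small margin over the matching audio crop. A sketch of the bookkeeping, using the defaults from `config/default.yaml`:

```python
segment_length, hop_length = 16000, 256
mel_segment_length = segment_length // hop_length + 2   # 62 + 2 = 64 frames
audio_span = mel_segment_length * hop_length            # 16384 samples
assert audio_span >= segment_length                     # 16000-sample audio crop is always covered
```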
melgan/hubconf.py
ADDED
@@ -0,0 +1,41 @@
+dependencies = ['torch']
+import torch
+from model.generator import Generator
+
+model_params = {
+    'nvidia_tacotron2_LJ11_epoch6400': {
+        'mel_channel': 80,
+        'model_url': 'https://github.com/seungwonpark/melgan/releases/download/v0.3-alpha/nvidia_tacotron2_LJ11_epoch6400.pt',
+    },
+}
+
+
+def melgan(model_name='nvidia_tacotron2_LJ11_epoch6400', pretrained=True, progress=True):
+    params = model_params[model_name]
+    model = Generator(params['mel_channel'])
+
+    if pretrained:
+        state_dict = torch.hub.load_state_dict_from_url(params['model_url'],
+                                                        progress=progress)
+        model.load_state_dict(state_dict['model_g'])
+
+    model.eval(inference=True)
+
+    return model
+
+
+if __name__ == '__main__':
+    vocoder = torch.hub.load('seungwonpark/melgan', 'melgan')
+    mel = torch.randn(1, 80, 234)  # use your own mel-spectrogram here
+
+    print('Input mel-spectrogram shape: {}'.format(mel.shape))
+
+    if torch.cuda.is_available():
+        print('Moving data & model to GPU')
+        vocoder = vocoder.cuda()
+        mel = mel.cuda()
+
+    with torch.no_grad():
+        audio = vocoder.inference(mel)
+
+    print('Output audio shape: {}'.format(audio.shape))
melgan/inference.py
ADDED
@@ -0,0 +1,49 @@
+import os
+import glob
+import tqdm
+import torch
+import argparse
+from scipy.io.wavfile import write
+
+from model.generator import Generator
+from utils.hparams import HParam, load_hparam_str
+
+MAX_WAV_VALUE = 32768.0
+
+
+def main(args):
+    checkpoint = torch.load(args.checkpoint_path)
+    if args.config is not None:
+        hp = HParam(args.config)
+    else:
+        hp = load_hparam_str(checkpoint['hp_str'])
+
+    model = Generator(hp.audio.n_mel_channels).cuda()
+    model.load_state_dict(checkpoint['model_g'])
+    model.eval(inference=False)
+
+    with torch.no_grad():
+        for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
+            mel = torch.load(melpath)
+            if len(mel.shape) == 2:
+                mel = mel.unsqueeze(0)
+            mel = mel.cuda()
+
+            audio = model.inference(mel)
+            audio = audio.cpu().detach().numpy()
+
+            out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
+            write(out_path, hp.audio.sampling_rate, audio)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-c', '--config', type=str, default=None,
+                        help="yaml file for config. will use hp_str from checkpoint if not given.")
+    parser.add_argument('-p', '--checkpoint_path', type=str, required=True,
+                        help="path of checkpoint pt file for evaluation")
+    parser.add_argument('-i', '--input_folder', type=str, required=True,
+                        help="directory of mel-spectrograms to invert into raw audio.")
+    args = parser.parse_args()
+
+    main(args)
melgan/model/discriminator.py
ADDED
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Discriminator(nn.Module):
+    def __init__(self):
+        super(Discriminator, self).__init__()
+
+        self.discriminator = nn.ModuleList([
+            nn.Sequential(
+                nn.ReflectionPad1d(7),
+                nn.utils.weight_norm(nn.Conv1d(1, 16, kernel_size=15, stride=1)),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv1d(16, 64, kernel_size=41, stride=4, padding=20, groups=4)),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv1d(64, 256, kernel_size=41, stride=4, padding=20, groups=16)),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv1d(256, 1024, kernel_size=41, stride=4, padding=20, groups=64)),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=41, stride=4, padding=20, groups=256)),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv1d(1024, 1024, kernel_size=5, stride=1, padding=2)),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.utils.weight_norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1)),
+        ])
+
+    def forward(self, x):
+        '''
+        returns: (list of 6 features, discriminator score)
+        we directly predict score without last sigmoid function
+        since we're using Least Squares GAN (https://arxiv.org/abs/1611.04076)
+        '''
+        features = list()
+        for module in self.discriminator:
+            x = module(x)
+            features.append(x)
+        return features[:-1], features[-1]
+
+
+if __name__ == '__main__':
+    model = Discriminator()
+
+    x = torch.randn(3, 1, 22050)
+    print(x.shape)
+
+    features, score = model(x)
+    for feat in features:
+        print(feat.shape)
+    print(score.shape)
+
+    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(pytorch_total_params)
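
For reference, these pre-sigmoid scores feed the least-squares GAN objectives assembled in `utils/train.py` below (summed over the three scales of the multi-scale discriminator, with $\lambda$ = `hp.model.feat_match` = 10.0). A sketch of the losses as implemented there, writing real audio as $x$ and generated audio as $\hat{x}$:

$$\mathcal{L}_D = \mathbb{E}\big[(D(x)-1)^2\big] + \mathbb{E}\big[D(\hat{x})^2\big], \qquad \mathcal{L}_G = \mathbb{E}\big[(D(\hat{x})-1)^2\big] + \lambda \sum_i \mathbb{E}\big[\,|D_i(x) - D_i(\hat{x})|\,\big]$$

where the $D_i$ are the intermediate feature maps that `forward` returns alongside the score.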
melgan/model/generator.py
ADDED
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .res_stack import ResStack
+# from res_stack import ResStack
+
+MAX_WAV_VALUE = 32768.0
+
+
+class Generator(nn.Module):
+    def __init__(self, mel_channel):
+        super(Generator, self).__init__()
+        self.mel_channel = mel_channel
+
+        self.generator = nn.Sequential(
+            nn.ReflectionPad1d(3),
+            nn.utils.weight_norm(nn.Conv1d(mel_channel, 512, kernel_size=7, stride=1)),
+
+            nn.LeakyReLU(0.2),
+            nn.utils.weight_norm(nn.ConvTranspose1d(512, 256, kernel_size=16, stride=8, padding=4)),
+
+            ResStack(256),
+
+            nn.LeakyReLU(0.2),
+            nn.utils.weight_norm(nn.ConvTranspose1d(256, 128, kernel_size=16, stride=8, padding=4)),
+
+            ResStack(128),
+
+            nn.LeakyReLU(0.2),
+            nn.utils.weight_norm(nn.ConvTranspose1d(128, 64, kernel_size=4, stride=2, padding=1)),
+
+            ResStack(64),
+
+            nn.LeakyReLU(0.2),
+            nn.utils.weight_norm(nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)),
+
+            ResStack(32),
+
+            nn.LeakyReLU(0.2),
+            nn.ReflectionPad1d(3),
+            nn.utils.weight_norm(nn.Conv1d(32, 1, kernel_size=7, stride=1)),
+            nn.Tanh(),
+        )
+
+    def forward(self, mel):
+        mel = (mel + 5.0) / 5.0  # roughly normalize spectrogram
+        return self.generator(mel)
+
+    def eval(self, inference=False):
+        super(Generator, self).eval()
+
+        # don't remove weight norm while validation in training loop
+        if inference:
+            self.remove_weight_norm()
+
+    def remove_weight_norm(self):
+        for idx, layer in enumerate(self.generator):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except:
+                    layer.remove_weight_norm()
+
+    def inference(self, mel):
+        hop_length = 256
+        # pad input mel with zeros to cut artifact
+        # see https://github.com/seungwonpark/melgan/issues/8
+        zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device)
+        mel = torch.cat((mel, zero), dim=2)
+
+        audio = self.forward(mel)
+        audio = audio.squeeze()  # collapse all dimension except time axis
+        audio = audio[:-(hop_length*10)]
+        audio = MAX_WAV_VALUE * audio
+        audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
+        audio = audio.short()
+
+        return audio
+
+
+'''
+to run this, fix
+from . import ResStack
+into
+from res_stack import ResStack
+'''
+if __name__ == '__main__':
+    model = Generator(80)
+
+    x = torch.randn(3, 80, 10)
+    print(x.shape)
+
+    y = model(x)
+    print(y.shape)
+    assert y.shape == torch.Size([3, 1, 2560])
+
+    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print(pytorch_total_params)
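
Two constants in `inference` are worth unpacking. The total upsampling factor of the four transposed convolutions is 8 × 8 × 2 × 2 = 256, which equals `hop_length` (one audio sample per mel hop), and −11.5129 is the compressed value of silence, since `dynamic_range_compression` in `utils/audio_processing.py` clamps at 1e-5 before taking the log. A quick check:

```python
import math

print(8 * 8 * 2 * 2)    # 256: the hop_length each mel frame expands to
print(math.log(1e-5))   # -11.512925..., i.e. log(clip_val) = "silence" in compressed mels
```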
melgan/model/identity.py
ADDED
@@ -0,0 +1,12 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Identity(nn.Module):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
melgan/model/multiscale.py
ADDED
@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .discriminator import Discriminator
+from .identity import Identity
+
+
+class MultiScaleDiscriminator(nn.Module):
+    def __init__(self):
+        super(MultiScaleDiscriminator, self).__init__()
+
+        self.discriminators = nn.ModuleList(
+            [Discriminator() for _ in range(3)]
+        )
+
+        self.pooling = nn.ModuleList(
+            [Identity()] +
+            [nn.AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False) for _ in range(1, 3)]
+        )
+
+    def forward(self, x):
+        ret = list()
+
+        for pool, disc in zip(self.pooling, self.discriminators):
+            x = pool(x)
+            ret.append(disc(x))
+
+        return ret  # [(feat, score), (feat, score), (feat, score)]
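
Because `forward` applies the pooling cumulatively (`x = pool(x)` before each discriminator), the three discriminators see the waveform at full, half, and quarter rate. A quick check with the repository's 22050 Hz sampling rate:

```python
sample_rate = 22050
for i in range(3):
    print(sample_rate // 2 ** i)   # 22050, 11025, 5512 Hz for discriminators 0, 1, 2
```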
melgan/model/res_stack.py
ADDED
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+
+class ResStack(nn.Module):
+    def __init__(self, channel):
+        super(ResStack, self).__init__()
+
+        self.blocks = nn.ModuleList([
+            nn.Sequential(
+                nn.LeakyReLU(0.2),
+                nn.ReflectionPad1d(3**i),
+                nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=3, dilation=3**i)),
+                nn.LeakyReLU(0.2),
+                nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
+            )
+            for i in range(3)
+        ])
+
+        self.shortcuts = nn.ModuleList([
+            nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
+            for i in range(3)
+        ])
+
+    def forward(self, x):
+        for block, shortcut in zip(self.blocks, self.shortcuts):
+            x = shortcut(x) + block(x)
+        return x
+
+    def remove_weight_norm(self):
+        for block, shortcut in zip(self.blocks, self.shortcuts):
+            nn.utils.remove_weight_norm(block[2])
+            nn.utils.remove_weight_norm(block[4])
+            nn.utils.remove_weight_norm(shortcut)
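
The dilations grow as 3^i (1, 3, 9), so each dilated 3-tap convolution adds 2·dilation samples of context while the 1×1 convolutions add none. A quick receptive-field computation for one `ResStack`:

```python
receptive_field = 1 + sum(2 * 3 ** i for i in range(3))
print(receptive_field)   # 27 samples of context per ResStack
```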
melgan/preprocess.py
ADDED
@@ -0,0 +1,50 @@
+import os
+import glob
+import tqdm
+import torch
+import argparse
+import numpy as np
+
+from utils.stft import TacotronSTFT
+from utils.hparams import HParam
+from utils.utils import read_wav_np
+
+
+def main(hp, args):
+    stft = TacotronSTFT(filter_length=hp.audio.filter_length,
+                        hop_length=hp.audio.hop_length,
+                        win_length=hp.audio.win_length,
+                        n_mel_channels=hp.audio.n_mel_channels,
+                        sampling_rate=hp.audio.sampling_rate,
+                        mel_fmin=hp.audio.mel_fmin,
+                        mel_fmax=hp.audio.mel_fmax)
+
+    wav_files = glob.glob(os.path.join(args.data_path, '**', '*.wav'), recursive=True)
+
+    for wavpath in tqdm.tqdm(wav_files, desc='preprocess wav to mel'):
+        sr, wav = read_wav_np(wavpath)
+        assert sr == hp.audio.sampling_rate, \
+            "sample rate mismatch. expected %d, got %d at %s" % \
+            (hp.audio.sampling_rate, sr, wavpath)
+
+        if len(wav) < hp.audio.segment_length + hp.audio.pad_short:
+            wav = np.pad(wav, (0, hp.audio.segment_length + hp.audio.pad_short - len(wav)), \
+                    mode='constant', constant_values=0.0)
+
+        wav = torch.from_numpy(wav).unsqueeze(0)
+        mel = stft.mel_spectrogram(wav)
+
+        melpath = wavpath.replace('.wav', '.mel')
+        torch.save(mel, melpath)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-c', '--config', type=str, required=True,
+                        help="yaml file for config.")
+    parser.add_argument('-d', '--data_path', type=str, required=True,
+                        help="root directory of wav files")
+    args = parser.parse_args()
+    hp = HParam(args.config)
+
+    main(hp, args)
melgan/requirements.txt
ADDED
@@ -0,0 +1,9 @@
+librosa
+matplotlib
+numpy
+scipy
+tensorboardX
+torch
+tqdm
+pillow
+pyyaml
melgan/trainer.py
ADDED
@@ -0,0 +1,52 @@
+import os
+import time
+import logging
+import argparse
+
+from utils.train import train
+from utils.hparams import HParam
+from utils.writer import MyWriter
+from datasets.dataloader import create_dataloader
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-c', '--config', type=str, required=True,
+                        help="yaml file for configuration")
+    parser.add_argument('-p', '--checkpoint_path', type=str, default=None,
+                        help="path of checkpoint pt file to resume training")
+    parser.add_argument('-n', '--name', type=str, required=True,
+                        help="name of the model for logging, saving checkpoint")
+    args = parser.parse_args()
+
+    hp = HParam(args.config)
+    with open(args.config, 'r') as f:
+        hp_str = ''.join(f.readlines())
+
+    pt_dir = os.path.join(hp.log.chkpt_dir, args.name)
+    log_dir = os.path.join(hp.log.log_dir, args.name)
+    os.makedirs(pt_dir, exist_ok=True)
+    os.makedirs(log_dir, exist_ok=True)
+
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+        handlers=[
+            logging.FileHandler(os.path.join(log_dir,
+                '%s-%d.log' % (args.name, time.time()))),
+            logging.StreamHandler()
+        ]
+    )
+    logger = logging.getLogger()
+
+    writer = MyWriter(hp, log_dir)
+
+    assert hp.audio.hop_length == 256, \
+        'hp.audio.hop_length must be equal to 256, got %d' % hp.audio.hop_length
+    assert hp.data.train != '' and hp.data.validation != '', \
+        'hp.data.train and hp.data.validation can\'t be empty: please fix %s' % args.config
+
+    trainloader = create_dataloader(hp, args, True)
+    valloader = create_dataloader(hp, args, False)
+
+    train(args, pt_dir, args.checkpoint_path, trainloader, valloader, writer, logger, hp, hp_str)
melgan/utils/audio_processing.py
ADDED
@@ -0,0 +1,93 @@
+import torch
+import numpy as np
+from scipy.signal import get_window
+import librosa.util as librosa_util
+
+
+def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
+                     n_fft=800, dtype=np.float32, norm=None):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time fourier transforms.
+
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+
+    n_frames : int > 0
+        The number of analysis frames
+
+    hop_length : int > 0
+        The number of samples to advance between frames
+
+    win_length : [optional]
+        The length of the window function. By default, this matches `n_fft`.
+
+    n_fft : int > 0
+        The length of each analysis frame.
+
+    dtype : np.dtype
+        The data type of the output
+
+    Returns
+    -------
+    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+        The sum-squared envelope of the window function
+    """
+    if win_length is None:
+        win_length = n_fft
+
+    n = n_fft + hop_length * (n_frames - 1)
+    x = np.zeros(n, dtype=dtype)
+
+    # Compute the squared window at the desired length
+    win_sq = get_window(window, win_length, fftbins=True)
+    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
+    win_sq = librosa_util.pad_center(win_sq, n_fft)
+
+    # Fill the envelope
+    for i in range(n_frames):
+        sample = i * hop_length
+        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
+    return x
+
+
+def griffin_lim(magnitudes, stft_fn, n_iters=30):
+    """
+    PARAMS
+    ------
+    magnitudes: spectrogram magnitudes
+    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+    """
+
+    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+    angles = angles.astype(np.float32)
+    angles = torch.autograd.Variable(torch.from_numpy(angles))
+    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+    for i in range(n_iters):
+        _, angles = stft_fn.transform(signal)
+        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+    return signal
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    """
+    PARAMS
+    ------
+    C: compression factor used to compress
+    """
+    return torch.exp(x) / C
melgan/utils/hparams.py
ADDED
@@ -0,0 +1,67 @@
+# modified from https://github.com/HarryVolek/PyTorch_Speaker_Verification
+
+import os
+import yaml
+
+
+def load_hparam_str(hp_str):
+    path = 'temp-restore.yaml'
+    with open(path, 'w') as f:
+        f.write(hp_str)
+    ret = HParam(path)
+    os.remove(path)
+    return ret
+
+
+def load_hparam(filename):
+    stream = open(filename, 'r')
+    docs = yaml.load_all(stream, Loader=yaml.Loader)
+    hparam_dict = dict()
+    for doc in docs:
+        for k, v in doc.items():
+            hparam_dict[k] = v
+    return hparam_dict
+
+
+def merge_dict(user, default):
+    if isinstance(user, dict) and isinstance(default, dict):
+        for k, v in default.items():
+            if k not in user:
+                user[k] = v
+            else:
+                user[k] = merge_dict(user[k], v)
+    return user
+
+
+class Dotdict(dict):
+    """
+    a dictionary that supports dot notation
+    as well as dictionary access notation
+    usage: d = DotDict() or d = DotDict({'val1':'first'})
+    set attributes: d.val2 = 'second' or d['val2'] = 'second'
+    get attributes: d.val2 or d['val2']
+    """
+    __getattr__ = dict.__getitem__
+    __setattr__ = dict.__setitem__
+    __delattr__ = dict.__delitem__
+
+    def __init__(self, dct=None):
+        dct = dict() if not dct else dct
+        for key, value in dct.items():
+            if hasattr(value, 'keys'):
+                value = Dotdict(value)
+            self[key] = value
+
+
+class HParam(Dotdict):
+
+    def __init__(self, file):
+        super(Dotdict, self).__init__()
+        hp_dict = load_hparam(file)
+        hp_dotdict = Dotdict(hp_dict)
+        for k, v in hp_dotdict.items():
+            setattr(self, k, v)
+
+    __getattr__ = Dotdict.__getitem__
+    __setattr__ = Dotdict.__setitem__
+    __delattr__ = Dotdict.__delitem__
melgan/utils/plotting.py
ADDED
@@ -0,0 +1,29 @@
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+import numpy as np
+
+
+def save_figure_to_numpy(fig):
+    # save it to a numpy array.
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    data = np.transpose(data, (2, 0, 1))
+    return data
+
+
+def plot_waveform_to_numpy(waveform):
+    fig, ax = plt.subplots(figsize=(12, 3))
+    ax.plot()
+    ax.plot(range(len(waveform)), waveform,
+            linewidth=0.1, alpha=0.7, color='blue')
+
+    plt.xlabel("Samples")
+    plt.ylabel("Amplitude")
+    plt.ylim(-1, 1)
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = save_figure_to_numpy(fig)
+    plt.close()
+    return data
melgan/utils/stft.py
ADDED
@@ -0,0 +1,184 @@
+"""
+BSD 3-Clause License
+
+Copyright (c) 2017, Prem Seetharaman
+All rights reserved.
+
+* Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice,
+this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this
+list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+contributors may be used to endorse or promote products derived from this
+software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+from torch.autograd import Variable
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny
+from .audio_processing import window_sumsquare, dynamic_range_compression, dynamic_range_decompression
+from librosa.filters import mel as librosa_mel_fn
+
+
+class STFT(torch.nn.Module):
+    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+    def __init__(self, filter_length=800, hop_length=200, win_length=800,
+                 window='hann'):
+        super(STFT, self).__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.window = window
+        self.forward_transform = None
+        scale = self.filter_length / self.hop_length
+        fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+        cutoff = int((self.filter_length / 2 + 1))
+        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
+                                   np.imag(fourier_basis[:cutoff, :])])
+
+        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+        inverse_basis = torch.FloatTensor(
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :])
+
+        if window is not None:
+            assert(filter_length >= win_length)
+            # get window and zero center pad it to filter_length
+            fft_window = get_window(window, win_length, fftbins=True)
+            fft_window = pad_center(fft_window, filter_length)
+            fft_window = torch.from_numpy(fft_window).float()
+
+            # window the bases
+            forward_basis *= fft_window
+            inverse_basis *= fft_window
+
+        self.register_buffer('forward_basis', forward_basis.float())
+        self.register_buffer('inverse_basis', inverse_basis.float())
+
+    def transform(self, input_data):
+        num_batches = input_data.size(0)
+        num_samples = input_data.size(1)
+
+        self.num_samples = num_samples
+
+        # similar to librosa, reflect-pad the input
+        input_data = input_data.view(num_batches, 1, num_samples)
+        input_data = F.pad(
+            input_data.unsqueeze(1),
+            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+            mode='reflect')
+        input_data = input_data.squeeze(1)
+
+        # https://github.com/NVIDIA/tacotron2/issues/125
+        forward_transform = F.conv1d(
+            input_data.cuda(),
+            Variable(self.forward_basis, requires_grad=False).cuda(),
+            stride=self.hop_length,
+            padding=0).cpu()
+
+        cutoff = int((self.filter_length / 2) + 1)
+        real_part = forward_transform[:, :cutoff, :]
+        imag_part = forward_transform[:, cutoff:, :]
+
+        magnitude = torch.sqrt(real_part**2 + imag_part**2)
+        phase = torch.autograd.Variable(
+            torch.atan2(imag_part.data, real_part.data))
+
+        return magnitude, phase
+
+    def inverse(self, magnitude, phase):
+        recombine_magnitude_phase = torch.cat(
+            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
+
+        inverse_transform = F.conv_transpose1d(
+            recombine_magnitude_phase,
+            Variable(self.inverse_basis, requires_grad=False),
+            stride=self.hop_length,
+            padding=0)
+
+        if self.window is not None:
+            window_sum = window_sumsquare(
+                self.window, magnitude.size(-1), hop_length=self.hop_length,
+                win_length=self.win_length, n_fft=self.filter_length,
+                dtype=np.float32)
+            # remove modulation effects
+            approx_nonzero_indices = torch.from_numpy(
+                np.where(window_sum > tiny(window_sum))[0])
+            window_sum = torch.autograd.Variable(
+                torch.from_numpy(window_sum), requires_grad=False)
+            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
+            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+
+            # scale by hop ratio
+            inverse_transform *= float(self.filter_length) / self.hop_length
+
+        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
+        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
+
+        return inverse_transform
+
+    def forward(self, input_data):
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
+
+
+class TacotronSTFT(torch.nn.Module):
+    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
+                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
+                 mel_fmax=None):
+        super(TacotronSTFT, self).__init__()
+        self.n_mel_channels = n_mel_channels
+        self.sampling_rate = sampling_rate
+        self.stft_fn = STFT(filter_length, hop_length, win_length)
+        mel_basis = librosa_mel_fn(
+            sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
+        mel_basis = torch.from_numpy(mel_basis).float()
+        self.register_buffer('mel_basis', mel_basis)
+
+    def spectral_normalize(self, magnitudes):
+        output = dynamic_range_compression(magnitudes)
+        return output
+
+    def spectral_de_normalize(self, magnitudes):
+        output = dynamic_range_decompression(magnitudes)
+        return output
+
+    def mel_spectrogram(self, y):
+        """Computes mel-spectrograms from a batch of waves
+        PARAMS
+        ------
+        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+        RETURNS
+        -------
+        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+        """
+        assert(torch.min(y.data) >= -1)
+        assert(torch.max(y.data) <= 1)
+
+        magnitudes, phases = self.stft_fn.transform(y)
+        magnitudes = magnitudes.data
+        mel_output = torch.matmul(self.mel_basis, magnitudes)
+        mel_output = self.spectral_normalize(mel_output)
+        return mel_output
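
With reflect padding of `filter_length // 2` on each side and stride `hop_length`, `transform` produces `T // hop_length + 1` frames for a T-sample input. This is the frame count `preprocess.py` stores, and the reason the data loader's two-frame margin (see `datasets/dataloader.py` above) is sufficient. A quick check with the default audio settings:

```python
T, filter_length, hop_length = 16000, 1024, 256
padded = T + 2 * (filter_length // 2)
n_frames = (padded - filter_length) // hop_length + 1
print(n_frames)   # 63 == 16000 // 256 + 1
```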
melgan/utils/train.py
ADDED
@@ -0,0 +1,131 @@
+import os
+import math
+import tqdm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import itertools
+import traceback
+
+from model.generator import Generator
+from model.multiscale import MultiScaleDiscriminator
+from .utils import get_commit_hash
+from .validation import validate
+
+
+def train(args, pt_dir, chkpt_path, trainloader, valloader, writer, logger, hp, hp_str):
+    model_g = Generator(hp.audio.n_mel_channels).cuda()
+    model_d = MultiScaleDiscriminator().cuda()
+
+    optim_g = torch.optim.Adam(model_g.parameters(),
+        lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
+    optim_d = torch.optim.Adam(model_d.parameters(),
+        lr=hp.train.adam.lr, betas=(hp.train.adam.beta1, hp.train.adam.beta2))
+
+    githash = get_commit_hash()
+
+    init_epoch = -1
+    step = 0
+
+    if chkpt_path is not None:
+        logger.info("Resuming from checkpoint: %s" % chkpt_path)
+        checkpoint = torch.load(chkpt_path)
+        model_g.load_state_dict(checkpoint['model_g'])
+        model_d.load_state_dict(checkpoint['model_d'])
+        optim_g.load_state_dict(checkpoint['optim_g'])
+        optim_d.load_state_dict(checkpoint['optim_d'])
+        step = checkpoint['step']
+        init_epoch = checkpoint['epoch']
+
+        if hp_str != checkpoint['hp_str']:
+            logger.warning("New hparams is different from checkpoint. Will use new.")
+
+        if githash != checkpoint['githash']:
+            logger.warning("Code might be different: git hash is different.")
+            logger.warning("%s -> %s" % (checkpoint['githash'], githash))
+
+    else:
+        logger.info("Starting new training run.")
+
+    # this accelerates training when the size of minibatch is always consistent.
+    # if not consistent, it'll horribly slow down.
+    torch.backends.cudnn.benchmark = True
+
+    try:
+        model_g.train()
+        model_d.train()
+        for epoch in itertools.count(init_epoch+1):
+            if epoch % hp.log.validation_interval == 0:
+                with torch.no_grad():
+                    validate(hp, args, model_g, model_d, valloader, writer, step)
+
+            trainloader.dataset.shuffle_mapping()
+            loader = tqdm.tqdm(trainloader, desc='Loading train data')
+            for (melG, audioG), (melD, audioD) in loader:
+                melG = melG.cuda()
+                audioG = audioG.cuda()
+                melD = melD.cuda()
+                audioD = audioD.cuda()
+
+                # generator
+                optim_g.zero_grad()
+                fake_audio = model_g(melG)[:, :, :hp.audio.segment_length]
+                disc_fake = model_d(fake_audio)
+                disc_real = model_d(audioG)
+                loss_g = 0.0
+                for (feats_fake, score_fake), (feats_real, _) in zip(disc_fake, disc_real):
+                    loss_g += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
+                    for feat_f, feat_r in zip(feats_fake, feats_real):
+                        loss_g += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
+
+                loss_g.backward()
+                optim_g.step()
+
+                # discriminator
+                fake_audio = model_g(melD)[:, :, :hp.audio.segment_length]
+                fake_audio = fake_audio.detach()
+                loss_d_sum = 0.0
+                for _ in range(hp.train.rep_discriminator):
+                    optim_d.zero_grad()
+                    disc_fake = model_d(fake_audio)
+                    disc_real = model_d(audioD)
+                    loss_d = 0.0
+                    for (_, score_fake), (_, score_real) in zip(disc_fake, disc_real):
+                        loss_d += torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
+                        loss_d += torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
+
+                    loss_d.backward()
+                    optim_d.step()
+                    loss_d_sum += loss_d
+
+                step += 1
+                # logging
+                loss_g = loss_g.item()
+                loss_d_avg = loss_d_sum / hp.train.rep_discriminator
+                loss_d_avg = loss_d_avg.item()
+                if any([loss_g > 1e8, math.isnan(loss_g), loss_d_avg > 1e8, math.isnan(loss_d_avg)]):
+                    logger.error("loss_g %.01f loss_d_avg %.01f at step %d!" % (loss_g, loss_d_avg, step))
+                    raise Exception("Loss exploded")
+
+                if step % hp.log.summary_interval == 0:
+                    writer.log_training(loss_g, loss_d_avg, step)
+                    loader.set_description("g %.04f d %.04f | step %d" % (loss_g, loss_d_avg, step))
+
+            if epoch % hp.log.save_interval == 0:
+                save_path = os.path.join(pt_dir, '%s_%s_%04d.pt'
+                    % (args.name, githash, epoch))
+                torch.save({
+                    'model_g': model_g.state_dict(),
+                    'model_d': model_d.state_dict(),
+                    'optim_g': optim_g.state_dict(),
+                    'optim_d': optim_d.state_dict(),
+                    'step': step,
+                    'epoch': epoch,
+                    'hp_str': hp_str,
+                    'githash': githash,
+                }, save_path)
+                logger.info("Saved checkpoint to: %s" % save_path)
+
+    except Exception as e:
+        logger.info("Exiting due to exception: %s" % e)
+        traceback.print_exc()
melgan/utils/utils.py
ADDED
@@ -0,0 +1,26 @@
+import random
+import subprocess
+import numpy as np
+from scipy.io.wavfile import read
+
+
+def get_commit_hash():
+    message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
+    return message.strip().decode('utf-8')
+
+def read_wav_np(path):
+    sr, wav = read(path)
+
+    if len(wav.shape) == 2:
+        wav = wav[:, 0]
+
+    if wav.dtype == np.int16:
+        wav = wav / 32768.0
+    elif wav.dtype == np.int32:
+        wav = wav / 2147483648.0
+    elif wav.dtype == np.uint8:
+        wav = (wav - 128) / 128.0
+
+    wav = wav.astype(np.float32)
+
+    return sr, wav
melgan/utils/validation.py
ADDED
@@ -0,0 +1,41 @@
+import tqdm
+import torch
+
+
+def validate(hp, args, generator, discriminator, valloader, writer, step):
+    generator.eval()
+    discriminator.eval()
+    torch.backends.cudnn.benchmark = False
+
+    loader = tqdm.tqdm(valloader, desc='Validation loop')
+    loss_g_sum = 0.0
+    loss_d_sum = 0.0
+    for mel, audio in loader:
+        mel = mel.cuda()
+        audio = audio.cuda()
+
+        # generator
+        fake_audio = generator(mel)
+        disc_fake = discriminator(fake_audio[:, :, :audio.size(2)])
+        disc_real = discriminator(audio)
+        loss_g = 0.0
+        loss_d = 0.0
+        for (feats_fake, score_fake), (feats_real, score_real) in zip(disc_fake, disc_real):
+            loss_g += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
+            for feat_f, feat_r in zip(feats_fake, feats_real):
+                loss_g += hp.model.feat_match * torch.mean(torch.abs(feat_f - feat_r))
+            loss_d += torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
+            loss_d += torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
+
+        loss_g_sum += loss_g.item()
+        loss_d_sum += loss_d.item()
+
+    loss_g_avg = loss_g_sum / len(valloader.dataset)
+    loss_d_avg = loss_d_sum / len(valloader.dataset)
+
+    audio = audio[0][0].cpu().detach().numpy()
+    fake_audio = fake_audio[0][0].cpu().detach().numpy()
+
+    writer.log_validation(loss_g_avg, loss_d_avg, generator, discriminator, audio, fake_audio, step)
+
+    torch.backends.cudnn.benchmark = True
melgan/utils/writer.py
ADDED
@@ -0,0 +1,33 @@
+from tensorboardX import SummaryWriter
+
+from .plotting import plot_waveform_to_numpy
+
+
+class MyWriter(SummaryWriter):
+    def __init__(self, hp, logdir):
+        super(MyWriter, self).__init__(logdir)
+        self.sample_rate = hp.audio.sampling_rate
+        self.is_first = True
+
+    def log_training(self, g_loss, d_loss, step):
+        self.add_scalar('train.g_loss', g_loss, step)
+        self.add_scalar('train.d_loss', d_loss, step)
+
+    def log_validation(self, g_loss, d_loss, generator, discriminator, target, prediction, step):
+        self.add_scalar('validation.g_loss', g_loss, step)
+        self.add_scalar('validation.d_loss', d_loss, step)
+
+        self.add_audio('raw_audio_predicted', prediction, step, self.sample_rate)
+        self.add_image('waveform_predicted', plot_waveform_to_numpy(prediction), step)
+
+        self.log_histogram(generator, step)
+        self.log_histogram(discriminator, step)
+
+        if self.is_first:
+            self.add_audio('raw_audio_target', target, step, self.sample_rate)
+            self.add_image('waveform_target', plot_waveform_to_numpy(target), step)
+            self.is_first = False
+
+    def log_histogram(self, model, step):
+        for tag, value in model.named_parameters():
+            self.add_histogram(tag.replace('.', '/'), value.cpu().detach().numpy(), step)