Spaces: Build error
update
This view is limited to 50 files because it contains too many changes. See raw diff.
- output/audio.wav +0 -0
- samples/1320_00000.mp3 +0 -0
- samples/3575_00000.mp3 +0 -0
- samples/6829_00000.mp3 +0 -0
- samples/8230_00000.mp3 +0 -0
- samples/README.md +22 -0
- samples/VCTK.txt +94 -0
- samples/p240_00000.mp3 +0 -0
- samples/p260_00000.mp3 +0 -0
- saved_models/default/encoder.pt +3 -0
- saved_models/default/synthesizer.pt +3 -0
- saved_models/default/vocoder.pt +3 -0
- saved_models/default/zh_synthesizer.pt +3 -0
- synthesizer/LICENSE.txt +24 -0
- synthesizer/__init__.py +1 -0
- synthesizer/__pycache__/__init__.cpython-37.pyc +0 -0
- synthesizer/__pycache__/audio.cpython-37.pyc +0 -0
- synthesizer/__pycache__/hparams.cpython-37.pyc +0 -0
- synthesizer/__pycache__/inference.cpython-37.pyc +0 -0
- synthesizer/audio.py +206 -0
- synthesizer/hparams.py +92 -0
- synthesizer/inference.py +165 -0
- synthesizer/models/__pycache__/tacotron.cpython-37.pyc +0 -0
- synthesizer/models/tacotron.py +519 -0
- synthesizer/preprocess.py +258 -0
- synthesizer/synthesize.py +92 -0
- synthesizer/synthesizer_dataset.py +92 -0
- synthesizer/train.py +258 -0
- synthesizer/utils/__init__.py +45 -0
- synthesizer/utils/__pycache__/__init__.cpython-37.pyc +0 -0
- synthesizer/utils/__pycache__/cleaners.cpython-37.pyc +0 -0
- synthesizer/utils/__pycache__/numbers.cpython-37.pyc +0 -0
- synthesizer/utils/__pycache__/symbols.cpython-37.pyc +0 -0
- synthesizer/utils/__pycache__/text.cpython-37.pyc +0 -0
- synthesizer/utils/_cmudict.py +62 -0
- synthesizer/utils/cleaners.py +88 -0
- synthesizer/utils/numbers.py +69 -0
- synthesizer/utils/plot.py +82 -0
- synthesizer/utils/symbols.py +21 -0
- synthesizer/utils/text.py +75 -0
- toolbox/__init__.py +347 -0
- toolbox/ui.py +607 -0
- toolbox/utterance.py +5 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-37.pyc +0 -0
- utils/__pycache__/argutils.cpython-37.pyc +0 -0
- utils/__pycache__/default_models.cpython-37.pyc +0 -0
- utils/argutils.py +40 -0
- utils/default_models.py +56 -0
- utils/logmmse.py +247 -0
output/audio.wav
ADDED
Binary file (189 kB)

samples/1320_00000.mp3
ADDED
Binary file (15.5 kB)

samples/3575_00000.mp3
ADDED
Binary file (15.5 kB)

samples/6829_00000.mp3
ADDED
Binary file (15.6 kB)

samples/8230_00000.mp3
ADDED
Binary file (16.1 kB)
samples/README.md
ADDED
@@ -0,0 +1,22 @@
+The audio files in this folder are provided for toolbox testing and
+benchmarking purposes. These are the same reference utterances
+used by the SV2TTS authors to generate the audio samples located at:
+https://google.github.io/tacotron/publications/speaker_adaptation/index.html
+
+The `p240_00000.mp3` and `p260_00000.mp3` files are compressed
+versions of audios from the VCTK corpus available at:
+https://datashare.is.ed.ac.uk/handle/10283/3443
+VCTK.txt contains the copyright notices and licensing information.
+
+The `1320_00000.mp3`, `3575_00000.mp3`, `6829_00000.mp3`
+and `8230_00000.mp3` files are compressed versions of audios
+from the LibriSpeech dataset available at: https://openslr.org/12
+For these files, the following notice applies:
+```
+LibriSpeech (c) 2014 by Vassil Panayotov
+
+LibriSpeech ASR corpus is licensed under a
+Creative Commons Attribution 4.0 International License.
+
+See <http://creativecommons.org/licenses/by/4.0/>.
+```
samples/VCTK.txt
ADDED
@@ -0,0 +1,94 @@
+---------------------------------------------------------------------
+CSTR VCTK Corpus
+English Multi-speaker Corpus for CSTR Voice Cloning Toolkit
+
+(Version 0.92)
+RELEASE September 2019
+The Centre for Speech Technology Research
+University of Edinburgh
+Copyright (c) 2019
+
+Junichi Yamagishi
+
+---------------------------------------------------------------------
+
+Overview
+
+This CSTR VCTK Corpus includes speech data uttered by 110 English
+speakers with various accents. Each speaker reads out about 400
+sentences, which were selected from a newspaper, the rainbow passage
+and an elicitation paragraph used for the speech accent archive.
+
+The newspaper texts were taken from Herald Glasgow, with permission
+from Herald & Times Group. Each speaker has a different set of the
+newspaper texts selected based a greedy algorithm that increases the
+contextual and phonetic coverage. The details of the text selection
+algorithms are described in the following paper:
+
+C. Veaux, J. Yamagishi and S. King,
+"The voice bank corpus: Design, collection and data analysis of
+a large regional accent speech database,"
+https://doi.org/10.1109/ICSDA.2013.6709856
+
+The rainbow passage and elicitation paragraph are the same for all
+speakers. The rainbow passage can be found at International Dialects
+of English Archive:
+(http://web.ku.edu/~idea/readings/rainbow.htm). The elicitation
+paragraph is identical to the one used for the speech accent archive
+(http://accent.gmu.edu). The details of the the speech accent archive
+can be found at
+http://www.ualberta.ca/~aacl2009/PDFs/WeinbergerKunath2009AACL.pdf
+
+All speech data was recorded using an identical recording setup: an
+omni-directional microphone (DPA 4035) and a small diaphragm condenser
+microphone with very wide bandwidth (Sennheiser MKH 800), 96kHz
+sampling frequency at 24 bits and in a hemi-anechoic chamber of
+the University of Edinburgh. (However, two speakers, p280 and p315
+had technical issues of the audio recordings using MKH 800).
+All recordings were converted into 16 bits, were downsampled to
+48 kHz, and were manually end-pointed.
+
+This corpus was originally aimed for HMM-based text-to-speech synthesis
+systems, especially for speaker-adaptive HMM-based speech synthesis
+that uses average voice models trained on multiple speakers and speaker
+adaptation technologies. This corpus is also suitable for DNN-based
+multi-speaker text-to-speech synthesis systems and waveform modeling.
+
+COPYING
+
+This corpus is licensed under the Creative Commons License: Attribution 4.0 International
+http://creativecommons.org/licenses/by/4.0/legalcode
+
+VCTK VARIANTS
+There are several variants of the VCTK corpus:
+Speech enhancement
+- Noisy speech database for training speech enhancement algorithms and TTS models where we added various types of noises to VCTK artificially: http://dx.doi.org/10.7488/ds/2117
+- Reverberant speech database for training speech dereverberation algorithms and TTS models where we added various types of reverberantion to VCTK artificially http://dx.doi.org/10.7488/ds/1425
+- Noisy reverberant speech database for training speech enhancement algorithms and TTS models http://dx.doi.org/10.7488/ds/2139
+- Device Recorded VCTK where speech signals of the VCTK corpus were played back and re-recorded in office environments using relatively inexpensive consumer devices http://dx.doi.org/10.7488/ds/2316
+- The Microsoft Scalable Noisy Speech Dataset (MS-SNSD) https://github.com/microsoft/MS-SNSD
+
+ASV and anti-spoofing
+- Spoofing and Anti-Spoofing (SAS) corpus, which is a collection of synthetic speech signals produced by nine techniques, two of which are speech synthesis, and seven are voice conversion. All of them were built using the VCTK corpus. http://dx.doi.org/10.7488/ds/252
+- Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2015) Database. This database consists of synthetic speech signals produced by ten techniques and this has been used in the first Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2015) http://dx.doi.org/10.7488/ds/298
+- ASVspoof 2019: The 3rd Automatic Speaker Verification Spoofing and Countermeasures Challenge database. This database has been used in the 3rd Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2019) https://doi.org/10.7488/ds/2555
+
+
+ACKNOWLEDGEMENTS
+
+The CSTR VCTK Corpus was constructed by:
+
+Christophe Veaux (University of Edinburgh)
+Junichi Yamagishi (University of Edinburgh)
+Kirsten MacDonald
+
+The research leading to these results was partly funded from EPSRC
+grants EP/I031022/1 (NST) and EP/J002526/1 (CAF), from the RSE-NSFC
+grant (61111130120), and from the JST CREST (uDialogue).
+
+Please cite this corpus as follows:
+Christophe Veaux, Junichi Yamagishi, Kirsten MacDonald,
+"CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit",
+The Centre for Speech Technology Research (CSTR),
+University of Edinburgh
+
samples/p240_00000.mp3
ADDED
Binary file (20.2 kB)

samples/p260_00000.mp3
ADDED
Binary file (20.5 kB)
saved_models/default/encoder.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39373b86598fa3da9fcddee6142382efe09777e8d37dc9c0561f41f0070f134e
+size 17090379

saved_models/default/synthesizer.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c05e07428f95d0ed8755e1ef54cc8ae251300413d94ce5867a56afe39c499d94
+size 370554559

saved_models/default/vocoder.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d7a6861589e927e0fbdaa5849ca022258fe2b58a20cc7bfb8fb598ccf936169
+size 53845290

saved_models/default/zh_synthesizer.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27de1bfd98fe7f99f99399c0349b35e213673d8181412deb914bc5593460dfb2
+size 370667477
synthesizer/LICENSE.txt
ADDED
@@ -0,0 +1,24 @@
+MIT License
+
+Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
+Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
+Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
+Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
synthesizer/__init__.py
ADDED
@@ -0,0 +1 @@
+#
synthesizer/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (163 Bytes)

synthesizer/__pycache__/audio.cpython-37.pyc
ADDED
Binary file (6.74 kB)

synthesizer/__pycache__/hparams.cpython-37.pyc
ADDED
Binary file (2.77 kB)

synthesizer/__pycache__/inference.cpython-37.pyc
ADDED
Binary file (6.3 kB)
synthesizer/audio.py
ADDED
@@ -0,0 +1,206 @@
+import librosa
+import librosa.filters
+import numpy as np
+from scipy import signal
+from scipy.io import wavfile
+import soundfile as sf
+
+
+def load_wav(path, sr):
+    return librosa.core.load(path, sr=sr)[0]
+
+def save_wav(wav, path, sr):
+    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+    #proposed by @dsmiller
+    wavfile.write(path, sr, wav.astype(np.int16))
+
+def save_wavenet_wav(wav, path, sr):
+    sf.write(path, wav.astype(np.float32), sr)
+
+def preemphasis(wav, k, preemphasize=True):
+    if preemphasize:
+        return signal.lfilter([1, -k], [1], wav)
+    return wav
+
+def inv_preemphasis(wav, k, inv_preemphasize=True):
+    if inv_preemphasize:
+        return signal.lfilter([1], [1, -k], wav)
+    return wav
+
+#From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
+def start_and_end_indices(quantized, silence_threshold=2):
+    for start in range(quantized.size):
+        if abs(quantized[start] - 127) > silence_threshold:
+            break
+    for end in range(quantized.size - 1, 1, -1):
+        if abs(quantized[end] - 127) > silence_threshold:
+            break
+
+    assert abs(quantized[start] - 127) > silence_threshold
+    assert abs(quantized[end] - 127) > silence_threshold
+
+    return start, end
+
+def get_hop_size(hparams):
+    hop_size = hparams.hop_size
+    if hop_size is None:
+        assert hparams.frame_shift_ms is not None
+        hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
+    return hop_size
+
+def linearspectrogram(wav, hparams):
+    D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
+    S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
+
+    if hparams.signal_normalization:
+        return _normalize(S, hparams)
+    return S
+
+def melspectrogram(wav, hparams):
+    D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
+    S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
+
+    if hparams.signal_normalization:
+        return _normalize(S, hparams)
+    return S
+
+def inv_linear_spectrogram(linear_spectrogram, hparams):
+    """Converts linear spectrogram to waveform using librosa"""
+    if hparams.signal_normalization:
+        D = _denormalize(linear_spectrogram, hparams)
+    else:
+        D = linear_spectrogram
+
+    S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
+
+    if hparams.use_lws:
+        processor = _lws_processor(hparams)
+        D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
+        y = processor.istft(D).astype(np.float32)
+        return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
+    else:
+        return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
+
+def inv_mel_spectrogram(mel_spectrogram, hparams):
+    """Converts mel spectrogram to waveform using librosa"""
+    if hparams.signal_normalization:
+        D = _denormalize(mel_spectrogram, hparams)
+    else:
+        D = mel_spectrogram
+
+    S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams)  # Convert back to linear
+
+    if hparams.use_lws:
+        processor = _lws_processor(hparams)
+        D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
+        y = processor.istft(D).astype(np.float32)
+        return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
+    else:
+        return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
+
+def _lws_processor(hparams):
+    import lws
+    return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
+
+def _griffin_lim(S, hparams):
+    """librosa implementation of Griffin-Lim
+    Based on https://github.com/librosa/librosa/issues/434
+    """
+    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
+    S_complex = np.abs(S).astype(np.complex)
+    y = _istft(S_complex * angles, hparams)
+    for i in range(hparams.griffin_lim_iters):
+        angles = np.exp(1j * np.angle(_stft(y, hparams)))
+        y = _istft(S_complex * angles, hparams)
+    return y
+
+def _stft(y, hparams):
+    if hparams.use_lws:
+        return _lws_processor(hparams).stft(y).T
+    else:
+        return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
+
+def _istft(y, hparams):
+    return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
+
+##########################################################
+#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
+def num_frames(length, fsize, fshift):
+    """Compute number of time frames of spectrogram
+    """
+    pad = (fsize - fshift)
+    if length % fshift == 0:
+        M = (length + pad * 2 - fsize) // fshift + 1
+    else:
+        M = (length + pad * 2 - fsize) // fshift + 2
+    return M
+
+
+def pad_lr(x, fsize, fshift):
+    """Compute left and right padding
+    """
+    M = num_frames(len(x), fsize, fshift)
+    pad = (fsize - fshift)
+    T = len(x) + 2 * pad
+    r = (M - 1) * fshift + fsize - T
+    return pad, pad + r
+##########################################################
+#Librosa correct padding
+def librosa_pad_lr(x, fsize, fshift):
+    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+
+# Conversions
+_mel_basis = None
+_inv_mel_basis = None
+
+def _linear_to_mel(spectogram, hparams):
+    global _mel_basis
+    if _mel_basis is None:
+        _mel_basis = _build_mel_basis(hparams)
+    return np.dot(_mel_basis, spectogram)
+
+def _mel_to_linear(mel_spectrogram, hparams):
+    global _inv_mel_basis
+    if _inv_mel_basis is None:
+        _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
+    return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
+
+def _build_mel_basis(hparams):
+    assert hparams.fmax <= hparams.sample_rate // 2
+    return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
+                               fmin=hparams.fmin, fmax=hparams.fmax)
+
+def _amp_to_db(x, hparams):
+    min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
+    return 20 * np.log10(np.maximum(min_level, x))
+
+def _db_to_amp(x):
+    return np.power(10.0, (x) * 0.05)
+
+def _normalize(S, hparams):
+    if hparams.allow_clipping_in_normalization:
+        if hparams.symmetric_mels:
+            return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
+                           -hparams.max_abs_value, hparams.max_abs_value)
+        else:
+            return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
+
+    assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
+    if hparams.symmetric_mels:
+        return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
+    else:
+        return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
+
+def _denormalize(D, hparams):
+    if hparams.allow_clipping_in_normalization:
+        if hparams.symmetric_mels:
+            return (((np.clip(D, -hparams.max_abs_value,
+                              hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
+                    + hparams.min_level_db)
+        else:
+            return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
+
+    if hparams.symmetric_mels:
+        return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
+    else:
+        return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
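For orientation, here is a minimal usage sketch of the audio module added above (not part of this commit). The input filename is a placeholder, and Griffin-Lim inversion is used since `use_lws` defaults to False in `hparams.py`:

```python
# Round-trip sketch (illustrative only): mel-spectrogram extraction and
# Griffin-Lim inversion using synthesizer.audio and synthesizer.hparams.
import numpy as np

from synthesizer import audio
from synthesizer.hparams import hparams

wav = audio.load_wav("example.wav", sr=hparams.sample_rate)  # placeholder path
if hparams.rescale:
    wav = wav / np.abs(wav).max() * hparams.rescaling_max

mel = audio.melspectrogram(wav, hparams)  # shape: (num_mels, frames)
print("mel shape:", mel.shape)

# Invert the mel spectrogram back to a rough waveform with Griffin-Lim
reconstructed = audio.inv_mel_spectrogram(mel, hparams)
audio.save_wav(reconstructed, "reconstructed.wav", sr=hparams.sample_rate)
```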
synthesizer/hparams.py
ADDED
@@ -0,0 +1,92 @@
+import ast
+import pprint
+
+class HParams(object):
+    def __init__(self, **kwargs): self.__dict__.update(kwargs)
+    def __setitem__(self, key, value): setattr(self, key, value)
+    def __getitem__(self, key): return getattr(self, key)
+    def __repr__(self): return pprint.pformat(self.__dict__)
+
+    def parse(self, string):
+        # Overrides hparams from a comma-separated string of name=value pairs
+        if len(string) > 0:
+            overrides = [s.split("=") for s in string.split(",")]
+            keys, values = zip(*overrides)
+            keys = list(map(str.strip, keys))
+            values = list(map(str.strip, values))
+            for k in keys:
+                self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
+        return self
+
+hparams = HParams(
+        ### Signal Processing (used in both synthesizer and vocoder)
+        sample_rate = 16000,
+        n_fft = 800,
+        num_mels = 80,
+        hop_size = 200,                             # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
+        win_size = 800,                             # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
+        fmin = 55,
+        min_level_db = -100,
+        ref_level_db = 20,
+        max_abs_value = 4.,                         # Gradient explodes if too big, premature convergence if too small.
+        preemphasis = 0.97,                         # Filter coefficient to use if preemphasize is True
+        preemphasize = True,
+
+        ### Tacotron Text-to-Speech (TTS)
+        tts_embed_dims = 512,                       # Embedding dimension for the graphemes/phoneme inputs
+        tts_encoder_dims = 256,
+        tts_decoder_dims = 128,
+        tts_postnet_dims = 512,
+        tts_encoder_K = 5,
+        tts_lstm_dims = 1024,
+        tts_postnet_K = 5,
+        tts_num_highways = 4,
+        tts_dropout = 0.5,
+        tts_cleaner_names = ["english_cleaners"],
+        tts_stop_threshold = -3.4,                  # Value below which audio generation ends.
+                                                    # For example, for a range of [-4, 4], this
+                                                    # will terminate the sequence at the first
+                                                    # frame that has all values < -3.4
+
+        ### Tacotron Training
+        tts_schedule = [(2,  1e-3,  20_000,  12),   # Progressive training schedule
+                        (2,  5e-4,  40_000,  12),   # (r, lr, step, batch_size)
+                        (2,  2e-4,  80_000,  12),   #
+                        (2,  1e-4, 160_000,  12),   # r = reduction factor (# of mel frames
+                        (2,  3e-5, 320_000,  12),   #     synthesized for each decoder iteration)
+                        (2,  1e-5, 640_000,  12)],  # lr = learning rate
+
+        tts_clip_grad_norm = 1.0,                   # clips the gradient norm to prevent explosion - set to None if not needed
+        tts_eval_interval = 500,                    # Number of steps between model evaluation (sample generation)
+                                                    # Set to -1 to generate after completing epoch, or 0 to disable
+
+        tts_eval_num_samples = 1,                   # Makes this number of samples
+
+        ### Data Preprocessing
+        max_mel_frames = 900,
+        rescale = True,
+        rescaling_max = 0.9,
+        synthesis_batch_size = 16,                  # For vocoder preprocessing and inference.
+
+        ### Mel Visualization and Griffin-Lim
+        signal_normalization = True,
+        power = 1.5,
+        griffin_lim_iters = 60,
+
+        ### Audio processing options
+        fmax = 7600,                                # Should not exceed (sample_rate // 2)
+        allow_clipping_in_normalization = True,     # Used when signal_normalization = True
+        clip_mels_length = True,                    # If true, discards samples exceeding max_mel_frames
+        use_lws = False,                            # "Fast spectrogram phase recovery using local weighted sums"
+        symmetric_mels = True,                      # Sets mel range to [-max_abs_value, max_abs_value] if True,
+                                                    #               and [0, max_abs_value] if False
+        trim_silence = True,                        # Use with sample_rate of 16000 for best results
+
+        ### SV2TTS
+        speaker_embedding_size = 256,               # Dimension for the speaker embedding
+        silence_min_duration_split = 0.4,           # Duration in seconds of a silence for an utterance to be split
+        utterance_min_duration = 1.6,               # Duration in seconds below which utterances are discarded
+        )
+
+def hparams_debug_string():
+    return str(hparams)
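A small sketch (again not part of the commit) of how the `HParams.parse` override string is interpreted; the override values below are illustrative only:

```python
# Override hyperparameters from a comma-separated "name=value" string,
# as supported by HParams.parse above (values parsed with ast.literal_eval).
from synthesizer.hparams import hparams

hparams.parse("tts_eval_interval=1000, griffin_lim_iters=30")
print(hparams.tts_eval_interval)    # 1000
print(hparams["griffin_lim_iters"])  # 30; __getitem__ maps to getattr
```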
synthesizer/inference.py
ADDED
@@ -0,0 +1,165 @@
+import torch
+from synthesizer import audio
+from synthesizer.hparams import hparams
+from synthesizer.models.tacotron import Tacotron
+from synthesizer.utils.symbols import symbols
+from synthesizer.utils.text import text_to_sequence
+from vocoder.display import simple_table
+from pathlib import Path
+from typing import Union, List
+import numpy as np
+import librosa
+
+
+class Synthesizer:
+    sample_rate = hparams.sample_rate
+    hparams = hparams
+
+    def __init__(self, model_fpath: Path, verbose=True):
+        """
+        The model isn't instantiated and loaded in memory until needed or until load() is called.
+
+        :param model_fpath: path to the trained model file
+        :param verbose: if False, prints less information when using the model
+        """
+        self.model_fpath = model_fpath
+        self.verbose = verbose
+
+        # Check for GPU
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+        if self.verbose:
+            print("Synthesizer using device:", self.device)
+
+        # Tacotron model will be instantiated later on first use.
+        self._model = None
+
+    def is_loaded(self):
+        """
+        Whether the model is loaded in memory.
+        """
+        return self._model is not None
+
+    def load(self):
+        """
+        Instantiates and loads the model given the weights file that was passed in the constructor.
+        """
+        self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
+                               num_chars=len(symbols),
+                               encoder_dims=hparams.tts_encoder_dims,
+                               decoder_dims=hparams.tts_decoder_dims,
+                               n_mels=hparams.num_mels,
+                               fft_bins=hparams.num_mels,
+                               postnet_dims=hparams.tts_postnet_dims,
+                               encoder_K=hparams.tts_encoder_K,
+                               lstm_dims=hparams.tts_lstm_dims,
+                               postnet_K=hparams.tts_postnet_K,
+                               num_highways=hparams.tts_num_highways,
+                               dropout=hparams.tts_dropout,
+                               stop_threshold=hparams.tts_stop_threshold,
+                               speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
+
+        self._model.load(self.model_fpath)
+        self._model.eval()
+
+        if self.verbose:
+            print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
+
+    def synthesize_spectrograms(self, texts: List[str],
+                                embeddings: Union[np.ndarray, List[np.ndarray]],
+                                return_alignments=False):
+        """
+        Synthesizes mel spectrograms from texts and speaker embeddings.
+
+        :param texts: a list of N text prompts to be synthesized
+        :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
+        :param return_alignments: if True, a matrix representing the alignments between the
+        characters
+        and each decoder output step will be returned for each spectrogram
+        :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
+        sequence length of spectrogram i, and possibly the alignments.
+        """
+        # Load the model on the first request.
+        if not self.is_loaded():
+            self.load()
+
+        # Preprocess text inputs
+        inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
+        if not isinstance(embeddings, list):
+            embeddings = [embeddings]
+
+        # Batch inputs
+        batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
+                          for i in range(0, len(inputs), hparams.synthesis_batch_size)]
+        batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
+                          for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
+
+        specs = []
+        for i, batch in enumerate(batched_inputs, 1):
+            if self.verbose:
+                print(f"\n| Generating {i}/{len(batched_inputs)}")
+
+            # Pad texts so they are all the same length
+            text_lens = [len(text) for text in batch]
+            max_text_len = max(text_lens)
+            chars = [pad1d(text, max_text_len) for text in batch]
+            chars = np.stack(chars)
+
+            # Stack speaker embeddings into 2D array for batch processing
+            speaker_embeds = np.stack(batched_embeds[i-1])
+
+            # Convert to tensor
+            chars = torch.tensor(chars).long().to(self.device)
+            speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
+
+            # Inference
+            _, mels, alignments = self._model.generate(chars, speaker_embeddings)
+            mels = mels.detach().cpu().numpy()
+            for m in mels:
+                # Trim silence from end of each spectrogram
+                while np.max(m[:, -1]) < hparams.tts_stop_threshold:
+                    m = m[:, :-1]
+                specs.append(m)
+
+        if self.verbose:
+            print("\n\nDone.\n")
+        return (specs, alignments) if return_alignments else specs
+
+    @staticmethod
+    def load_preprocess_wav(fpath):
+        """
+        Loads and preprocesses an audio file under the same conditions the audio files were used to
+        train the synthesizer.
+        """
+        wav = librosa.load(str(fpath), hparams.sample_rate)[0]
+        if hparams.rescale:
+            wav = wav / np.abs(wav).max() * hparams.rescaling_max
+        return wav
+
+    @staticmethod
+    def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
+        """
+        Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
+        were fed to the synthesizer when training.
+        """
+        if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
+            wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
+        else:
+            wav = fpath_or_wav
+
+        mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
+        return mel_spectrogram
+
+    @staticmethod
+    def griffin_lim(mel):
+        """
+        Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
+        with the same parameters present in hparams.py.
+        """
+        return audio.inv_mel_spectrogram(mel, hparams)
+
+
+def pad1d(x, max_len, pad_value=0):
+    return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
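A hedged end-to-end sketch of the `Synthesizer` API above (not part of the commit). The checkpoint path is the one added under `saved_models/default/`; the speaker embedding here is a random stand-in, whereas the toolbox derives it from the speaker encoder:

```python
# Usage sketch for Synthesizer.synthesize_spectrograms (illustrative only).
# A real embedding would come from the encoder; a random unit-norm vector
# is used purely as a placeholder and will not produce a meaningful voice.
from pathlib import Path
import numpy as np

from synthesizer.inference import Synthesizer

synth = Synthesizer(Path("saved_models/default/synthesizer.pt"))

embed = np.random.rand(256).astype(np.float32)
embed /= np.linalg.norm(embed)  # embeddings are expected to be L2-normalized

specs = synth.synthesize_spectrograms(["Hello world."], [embed])
mel = specs[0]                       # (80, frames) mel spectrogram
wav = Synthesizer.griffin_lim(mel)   # rough waveform without the neural vocoder
print(wav.shape, Synthesizer.sample_rate)
```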
synthesizer/models/__pycache__/tacotron.cpython-37.pyc
ADDED
Binary file (14.2 kB)
synthesizer/models/tacotron.py
ADDED
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
|
10 |
+
class HighwayNetwork(nn.Module):
|
11 |
+
def __init__(self, size):
|
12 |
+
super().__init__()
|
13 |
+
self.W1 = nn.Linear(size, size)
|
14 |
+
self.W2 = nn.Linear(size, size)
|
15 |
+
self.W1.bias.data.fill_(0.)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
x1 = self.W1(x)
|
19 |
+
x2 = self.W2(x)
|
20 |
+
g = torch.sigmoid(x2)
|
21 |
+
y = g * F.relu(x1) + (1. - g) * x
|
22 |
+
return y
|
23 |
+
|
24 |
+
|
25 |
+
class Encoder(nn.Module):
|
26 |
+
def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
|
27 |
+
super().__init__()
|
28 |
+
prenet_dims = (encoder_dims, encoder_dims)
|
29 |
+
cbhg_channels = encoder_dims
|
30 |
+
self.embedding = nn.Embedding(num_chars, embed_dims)
|
31 |
+
self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
32 |
+
dropout=dropout)
|
33 |
+
self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
|
34 |
+
proj_channels=[cbhg_channels, cbhg_channels],
|
35 |
+
num_highways=num_highways)
|
36 |
+
|
37 |
+
def forward(self, x, speaker_embedding=None):
|
38 |
+
x = self.embedding(x)
|
39 |
+
x = self.pre_net(x)
|
40 |
+
x.transpose_(1, 2)
|
41 |
+
x = self.cbhg(x)
|
42 |
+
if speaker_embedding is not None:
|
43 |
+
x = self.add_speaker_embedding(x, speaker_embedding)
|
44 |
+
return x
|
45 |
+
|
46 |
+
def add_speaker_embedding(self, x, speaker_embedding):
|
47 |
+
# SV2TTS
|
48 |
+
# The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
|
49 |
+
# When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
|
50 |
+
# (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
|
51 |
+
# This concats the speaker embedding for each char in the encoder output
|
52 |
+
|
53 |
+
# Save the dimensions as human-readable names
|
54 |
+
batch_size = x.size()[0]
|
55 |
+
num_chars = x.size()[1]
|
56 |
+
|
57 |
+
if speaker_embedding.dim() == 1:
|
58 |
+
idx = 0
|
59 |
+
else:
|
60 |
+
idx = 1
|
61 |
+
|
62 |
+
# Start by making a copy of each speaker embedding to match the input text length
|
63 |
+
# The output of this has size (batch_size, num_chars * tts_embed_dims)
|
64 |
+
speaker_embedding_size = speaker_embedding.size()[idx]
|
65 |
+
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
|
66 |
+
|
67 |
+
# Reshape it and transpose
|
68 |
+
e = e.reshape(batch_size, speaker_embedding_size, num_chars)
|
69 |
+
e = e.transpose(1, 2)
|
70 |
+
|
71 |
+
# Concatenate the tiled speaker embedding with the encoder output
|
72 |
+
x = torch.cat((x, e), 2)
|
73 |
+
return x
|
74 |
+
|
75 |
+
|
76 |
+
class BatchNormConv(nn.Module):
|
77 |
+
def __init__(self, in_channels, out_channels, kernel, relu=True):
|
78 |
+
super().__init__()
|
79 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
|
80 |
+
self.bnorm = nn.BatchNorm1d(out_channels)
|
81 |
+
self.relu = relu
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
x = self.conv(x)
|
85 |
+
x = F.relu(x) if self.relu is True else x
|
86 |
+
return self.bnorm(x)
|
87 |
+
|
88 |
+
|
89 |
+
class CBHG(nn.Module):
|
90 |
+
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
|
91 |
+
super().__init__()
|
92 |
+
|
93 |
+
# List of all rnns to call `flatten_parameters()` on
|
94 |
+
self._to_flatten = []
|
95 |
+
|
96 |
+
self.bank_kernels = [i for i in range(1, K + 1)]
|
97 |
+
self.conv1d_bank = nn.ModuleList()
|
98 |
+
for k in self.bank_kernels:
|
99 |
+
conv = BatchNormConv(in_channels, channels, k)
|
100 |
+
self.conv1d_bank.append(conv)
|
101 |
+
|
102 |
+
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
103 |
+
|
104 |
+
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
|
105 |
+
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
|
106 |
+
|
107 |
+
# Fix the highway input if necessary
|
108 |
+
if proj_channels[-1] != channels:
|
109 |
+
self.highway_mismatch = True
|
110 |
+
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
|
111 |
+
else:
|
112 |
+
self.highway_mismatch = False
|
113 |
+
|
114 |
+
self.highways = nn.ModuleList()
|
115 |
+
for i in range(num_highways):
|
116 |
+
hn = HighwayNetwork(channels)
|
117 |
+
self.highways.append(hn)
|
118 |
+
|
119 |
+
self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
|
120 |
+
self._to_flatten.append(self.rnn)
|
121 |
+
|
122 |
+
# Avoid fragmentation of RNN parameters and associated warning
|
123 |
+
self._flatten_parameters()
|
124 |
+
|
125 |
+
def forward(self, x):
|
126 |
+
# Although we `_flatten_parameters()` on init, when using DataParallel
|
127 |
+
# the model gets replicated, making it no longer guaranteed that the
|
128 |
+
# weights are contiguous in GPU memory. Hence, we must call it again
|
129 |
+
self._flatten_parameters()
|
130 |
+
|
131 |
+
# Save these for later
|
132 |
+
residual = x
|
133 |
+
seq_len = x.size(-1)
|
134 |
+
conv_bank = []
|
135 |
+
|
136 |
+
# Convolution Bank
|
137 |
+
for conv in self.conv1d_bank:
|
138 |
+
c = conv(x) # Convolution
|
139 |
+
conv_bank.append(c[:, :, :seq_len])
|
140 |
+
|
141 |
+
# Stack along the channel axis
|
142 |
+
conv_bank = torch.cat(conv_bank, dim=1)
|
143 |
+
|
144 |
+
# dump the last padding to fit residual
|
145 |
+
x = self.maxpool(conv_bank)[:, :, :seq_len]
|
146 |
+
|
147 |
+
# Conv1d projections
|
148 |
+
x = self.conv_project1(x)
|
149 |
+
x = self.conv_project2(x)
|
150 |
+
|
151 |
+
# Residual Connect
|
152 |
+
x = x + residual
|
153 |
+
|
154 |
+
# Through the highways
|
155 |
+
x = x.transpose(1, 2)
|
156 |
+
if self.highway_mismatch is True:
|
157 |
+
x = self.pre_highway(x)
|
158 |
+
for h in self.highways: x = h(x)
|
159 |
+
|
160 |
+
# And then the RNN
|
161 |
+
x, _ = self.rnn(x)
|
162 |
+
return x
|
163 |
+
|
164 |
+
def _flatten_parameters(self):
|
165 |
+
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
166 |
+
to improve efficiency and avoid PyTorch yelling at us."""
|
167 |
+
[m.flatten_parameters() for m in self._to_flatten]
|
168 |
+
|
169 |
+
class PreNet(nn.Module):
|
170 |
+
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
|
171 |
+
super().__init__()
|
172 |
+
self.fc1 = nn.Linear(in_dims, fc1_dims)
|
173 |
+
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
|
174 |
+
self.p = dropout
|
175 |
+
|
176 |
+
def forward(self, x):
|
177 |
+
x = self.fc1(x)
|
178 |
+
x = F.relu(x)
|
179 |
+
x = F.dropout(x, self.p, training=True)
|
180 |
+
x = self.fc2(x)
|
181 |
+
x = F.relu(x)
|
182 |
+
x = F.dropout(x, self.p, training=True)
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
class Attention(nn.Module):
|
187 |
+
def __init__(self, attn_dims):
|
188 |
+
super().__init__()
|
189 |
+
self.W = nn.Linear(attn_dims, attn_dims, bias=False)
|
190 |
+
self.v = nn.Linear(attn_dims, 1, bias=False)
|
191 |
+
|
192 |
+
def forward(self, encoder_seq_proj, query, t):
|
193 |
+
|
194 |
+
# print(encoder_seq_proj.shape)
|
195 |
+
# Transform the query vector
|
196 |
+
query_proj = self.W(query).unsqueeze(1)
|
197 |
+
|
198 |
+
# Compute the scores
|
199 |
+
u = self.v(torch.tanh(encoder_seq_proj + query_proj))
|
200 |
+
scores = F.softmax(u, dim=1)
|
201 |
+
|
202 |
+
return scores.transpose(1, 2)
|
203 |
+
|
204 |
+
|
205 |
+
class LSA(nn.Module):
|
206 |
+
def __init__(self, attn_dim, kernel_size=31, filters=32):
|
207 |
+
super().__init__()
|
208 |
+
self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
|
209 |
+
self.L = nn.Linear(filters, attn_dim, bias=False)
|
210 |
+
self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
|
211 |
+
self.v = nn.Linear(attn_dim, 1, bias=False)
|
212 |
+
self.cumulative = None
|
213 |
+
self.attention = None
|
214 |
+
|
215 |
+
def init_attention(self, encoder_seq_proj):
|
216 |
+
device = next(self.parameters()).device # use same device as parameters
|
217 |
+
b, t, c = encoder_seq_proj.size()
|
218 |
+
self.cumulative = torch.zeros(b, t, device=device)
|
219 |
+
self.attention = torch.zeros(b, t, device=device)
|
220 |
+
|
221 |
+
def forward(self, encoder_seq_proj, query, t, chars):
|
222 |
+
|
223 |
+
if t == 0: self.init_attention(encoder_seq_proj)
|
224 |
+
|
225 |
+
processed_query = self.W(query).unsqueeze(1)
|
226 |
+
|
227 |
+
location = self.cumulative.unsqueeze(1)
|
228 |
+
processed_loc = self.L(self.conv(location).transpose(1, 2))
|
229 |
+
|
230 |
+
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
|
231 |
+
u = u.squeeze(-1)
|
232 |
+
|
233 |
+
# Mask zero padding chars
|
234 |
+
u = u * (chars != 0).float()
|
235 |
+
|
236 |
+
# Smooth Attention
|
237 |
+
# scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
|
238 |
+
scores = F.softmax(u, dim=1)
|
239 |
+
self.attention = scores
|
240 |
+
self.cumulative = self.cumulative + self.attention
|
241 |
+
|
242 |
+
return scores.unsqueeze(-1).transpose(1, 2)
|
243 |
+
|
244 |
+
|
245 |
+
class Decoder(nn.Module):
|
246 |
+
# Class variable because its value doesn't change between classes
|
247 |
+
# yet ought to be scoped by class because its a property of a Decoder
|
248 |
+
max_r = 20
|
249 |
+
def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
|
250 |
+
dropout, speaker_embedding_size):
|
251 |
+
super().__init__()
|
252 |
+
self.register_buffer("r", torch.tensor(1, dtype=torch.int))
|
253 |
+
self.n_mels = n_mels
|
254 |
+
prenet_dims = (decoder_dims * 2, decoder_dims * 2)
|
255 |
+
self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
|
256 |
+
dropout=dropout)
|
257 |
+
self.attn_net = LSA(decoder_dims)
|
258 |
+
self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
|
259 |
+
self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
|
260 |
+
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
|
261 |
+
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
|
262 |
+
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
|
263 |
+
self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
|
264 |
+
|
265 |
+
def zoneout(self, prev, current, p=0.1):
|
266 |
+
device = next(self.parameters()).device # Use same device as parameters
|
267 |
+
mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
|
268 |
+
return prev * mask + current * (1 - mask)
|
269 |
+
|
270 |
+
def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
|
271 |
+
hidden_states, cell_states, context_vec, t, chars):
|
272 |
+
|
273 |
+
# Need this for reshaping mels
|
274 |
+
batch_size = encoder_seq.size(0)
|
275 |
+
|
276 |
+
# Unpack the hidden and cell states
|
277 |
+
attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
|
278 |
+
rnn1_cell, rnn2_cell = cell_states
|
279 |
+
|
280 |
+
# PreNet for the Attention RNN
|
281 |
+
prenet_out = self.prenet(prenet_in)
|
282 |
+
|
283 |
+
# Compute the Attention RNN hidden state
|
284 |
+
attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
|
285 |
+
attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
|
286 |
+
|
287 |
+
# Compute the attention scores
|
288 |
+
scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
|
289 |
+
|
290 |
+
# Dot product to create the context vector
|
291 |
+
context_vec = scores @ encoder_seq
|
292 |
+
context_vec = context_vec.squeeze(1)
|
293 |
+
|
294 |
+
# Concat Attention RNN output w. Context Vector & project
|
295 |
+
x = torch.cat([context_vec, attn_hidden], dim=1)
|
296 |
+
x = self.rnn_input(x)
|
297 |
+
|
298 |
+
# Compute first Residual RNN
|
299 |
+
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
|
300 |
+
if self.training:
|
301 |
+
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
|
302 |
+
else:
|
303 |
+
rnn1_hidden = rnn1_hidden_next
|
304 |
+
x = x + rnn1_hidden
|
305 |
+
|
306 |
+
# Compute second Residual RNN
|
307 |
+
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
|
308 |
+
if self.training:
|
309 |
+
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
|
310 |
+
else:
|
311 |
+
rnn2_hidden = rnn2_hidden_next
|
312 |
+
x = x + rnn2_hidden
|
313 |
+
|
314 |
+
# Project Mels
|
315 |
+
mels = self.mel_proj(x)
|
316 |
+
mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
|
317 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
318 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
319 |
+
|
320 |
+
# Stop token prediction
|
321 |
+
s = torch.cat((x, context_vec), dim=1)
|
322 |
+
s = self.stop_proj(s)
|
323 |
+
stop_tokens = torch.sigmoid(s)
|
324 |
+
|
325 |
+
return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
|
326 |
+
|
327 |
+
|
328 |
+
class Tacotron(nn.Module):
|
329 |
+
def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
|
330 |
+
fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
|
331 |
+
dropout, stop_threshold, speaker_embedding_size):
|
332 |
+
super().__init__()
|
333 |
+
self.n_mels = n_mels
|
334 |
+
self.lstm_dims = lstm_dims
|
335 |
+
self.encoder_dims = encoder_dims
|
336 |
+
self.decoder_dims = decoder_dims
|
337 |
+
self.speaker_embedding_size = speaker_embedding_size
|
338 |
+
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
|
339 |
+
encoder_K, num_highways, dropout)
|
340 |
+
self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
|
341 |
+
self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
|
342 |
+
dropout, speaker_embedding_size)
|
343 |
+
self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
|
344 |
+
[postnet_dims, fft_bins], num_highways)
|
345 |
+
self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
|
346 |
+
|
347 |
+
self.init_model()
|
348 |
+
self.num_params()
|
349 |
+
|
350 |
+
self.register_buffer("step", torch.zeros(1, dtype=torch.long))
|
351 |
+
self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
|
352 |
+
|
353 |
+
@property
|
354 |
+
def r(self):
|
355 |
+
return self.decoder.r.item()
|
356 |
+
|
357 |
+
@r.setter
|
358 |
+
def r(self, value):
|
359 |
+
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
360 |
+
|
361 |
+
def forward(self, x, m, speaker_embedding):
|
362 |
+
device = next(self.parameters()).device # use same device as parameters
|
363 |
+
|
364 |
+
self.step += 1
|
365 |
+
batch_size, _, steps = m.size()
|
366 |
+
|
367 |
+
# Initialise all hidden states and pack into tuple
|
368 |
+
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
369 |
+
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
370 |
+
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
371 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
372 |
+
|
373 |
+
# Initialise all lstm cell states and pack into tuple
|
374 |
+
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
375 |
+
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
376 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
377 |
+
|
378 |
+
# <GO> Frame for start of decoder loop
|
379 |
+
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
380 |
+
|
381 |
+
# Need an initial context vector
|
382 |
+
context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
|
383 |
+
|
384 |
+
# SV2TTS: Run the encoder with the speaker embedding
|
385 |
+
# The projection avoids unnecessary matmuls in the decoder loop
|
386 |
+
encoder_seq = self.encoder(x, speaker_embedding)
|
387 |
+
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
388 |
+
|
389 |
+
# Need a couple of lists for outputs
|
390 |
+
mel_outputs, attn_scores, stop_outputs = [], [], []
|
391 |
+
|
392 |
+
        # Run the decoder loop
        for t in range(0, steps, self.r):
            prenet_in = m[:, :, t - 1] if t > 0 else go_frame
            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                             hidden_states, cell_states, context_vec, t, x)
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            stop_outputs.extend([stop_tokens] * self.r)

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        # Post-Process for Linear Spectrograms
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        linear = linear.transpose(1, 2)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
        # attn_scores = attn_scores.cpu().data.numpy()
        stop_outputs = torch.cat(stop_outputs, 1)

        return mel_outputs, linear, attn_scores, stop_outputs

    def generate(self, x, speaker_embedding=None, steps=2000):
        self.eval()
        device = next(self.parameters()).device  # use same device as parameters

        batch_size, _ = x.size()

        # Need to initialise all hidden states and pack into tuple for tidyness
        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)

        # Need to initialise all lstm cell states and pack into tuple for tidyness
        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        cell_states = (rnn1_cell, rnn2_cell)

        # Need a <GO> Frame for start of decoder loop
        go_frame = torch.zeros(batch_size, self.n_mels, device=device)

        # Need an initial context vector
        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(x, speaker_embedding)
        encoder_seq_proj = self.encoder_proj(encoder_seq)

        # Need a couple of lists for outputs
        mel_outputs, attn_scores, stop_outputs = [], [], []

        # Run the decoder loop
        for t in range(0, steps, self.r):
            prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                             hidden_states, cell_states, context_vec, t, x)
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            stop_outputs.extend([stop_tokens] * self.r)
            # Stop the loop when all stop tokens in batch exceed threshold
            if (stop_tokens > 0.5).all() and t > 10: break

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        # Post-Process for Linear Spectrograms
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        linear = linear.transpose(1, 2)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
        stop_outputs = torch.cat(stop_outputs, 1)

        self.train()

        return mel_outputs, linear, attn_scores

    def init_model(self):
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)

    def get_step(self):
        return self.step.data.item()

    def reset_step(self):
        # assignment to parameters or buffers is overloaded, updates internal dict entry
        self.step = self.step.data.new_tensor(1)

    def log(self, path, msg):
        with open(path, "a") as f:
            print(msg, file=f)

    def load(self, path, optimizer=None):
        # Use device of model params as location for loaded state
        device = next(self.parameters()).device
        checkpoint = torch.load(str(path), map_location=device)
        self.load_state_dict(checkpoint["model_state"])

        if "optimizer_state" in checkpoint and optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state"])

    def save(self, path, optimizer=None):
        if optimizer is not None:
            torch.save({
                "model_state": self.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, str(path))
        else:
            torch.save({
                "model_state": self.state_dict(),
            }, str(path))

    def num_params(self, print_out=True):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print("Trainable Parameters: %.3fM" % parameters)
        return parameters
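A quick way to sanity-check the `generate` path above is a CPU smoke test with random weights. This is a minimal sketch, not part of the commit: it assumes the full `Tacotron` class defined earlier in this file plus the `hparams` and `symbols` modules added elsewhere in this diff, and all input values are dummies.
```
import torch

from synthesizer.hparams import hparams
from synthesizer.models.tacotron import Tacotron
from synthesizer.utils.symbols import symbols

# Build the model the same way synthesize.py / train.py in this commit do
model = Tacotron(embed_dims=hparams.tts_embed_dims,
                 num_chars=len(symbols),
                 encoder_dims=hparams.tts_encoder_dims,
                 decoder_dims=hparams.tts_decoder_dims,
                 n_mels=hparams.num_mels,
                 fft_bins=hparams.num_mels,
                 postnet_dims=hparams.tts_postnet_dims,
                 encoder_K=hparams.tts_encoder_K,
                 lstm_dims=hparams.tts_lstm_dims,
                 postnet_K=hparams.tts_postnet_K,
                 num_highways=hparams.tts_num_highways,
                 dropout=hparams.tts_dropout,
                 stop_threshold=hparams.tts_stop_threshold,
                 speaker_embedding_size=hparams.speaker_embedding_size)
model.r = 2  # reduction factor: mel frames emitted per decoder step

# Dummy inputs: one "sentence" of 20 random symbol ids and one random speaker embedding
chars = torch.randint(0, len(symbols), (1, 20)).long()
speaker_embed = torch.rand(1, hparams.speaker_embedding_size)

mels, linear, attn = model.generate(chars, speaker_embed, steps=200)
print(mels.shape, linear.shape, attn.shape)  # untrained model, so only the shapes are meaningful
```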
synthesizer/preprocess.py
ADDED
@@ -0,0 +1,258 @@
1 |
+
from multiprocessing.pool import Pool
|
2 |
+
from synthesizer import audio
|
3 |
+
from functools import partial
|
4 |
+
from itertools import chain
|
5 |
+
from encoder import inference as encoder
|
6 |
+
from pathlib import Path
|
7 |
+
from utils import logmmse
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import librosa
|
11 |
+
|
12 |
+
|
13 |
+
def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, skip_existing: bool, hparams,
|
14 |
+
no_alignments: bool, datasets_name: str, subfolders: str):
|
15 |
+
# Gather the input directories
|
16 |
+
dataset_root = datasets_root.joinpath(datasets_name)
|
17 |
+
input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")]
|
18 |
+
print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
|
19 |
+
assert all(input_dir.exists() for input_dir in input_dirs)
|
20 |
+
|
21 |
+
# Create the output directories for each output file type
|
22 |
+
out_dir.joinpath("mels").mkdir(exist_ok=True)
|
23 |
+
out_dir.joinpath("audio").mkdir(exist_ok=True)
|
24 |
+
|
25 |
+
# Create a metadata file
|
26 |
+
metadata_fpath = out_dir.joinpath("train.txt")
|
27 |
+
metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")
|
28 |
+
|
29 |
+
# Preprocess the dataset
|
30 |
+
speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
|
31 |
+
func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing,
|
32 |
+
hparams=hparams, no_alignments=no_alignments)
|
33 |
+
job = Pool(n_processes).imap(func, speaker_dirs)
|
34 |
+
for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"):
|
35 |
+
for metadatum in speaker_metadata:
|
36 |
+
metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
|
37 |
+
metadata_file.close()
|
38 |
+
|
39 |
+
# Verify the contents of the metadata file
|
40 |
+
with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
|
41 |
+
metadata = [line.split("|") for line in metadata_file]
|
42 |
+
mel_frames = sum([int(m[4]) for m in metadata])
|
43 |
+
timesteps = sum([int(m[3]) for m in metadata])
|
44 |
+
sample_rate = hparams.sample_rate
|
45 |
+
hours = (timesteps / sample_rate) / 3600
|
46 |
+
print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
|
47 |
+
(len(metadata), mel_frames, timesteps, hours))
|
48 |
+
print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
|
49 |
+
print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
|
50 |
+
print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))
|
51 |
+
|
52 |
+
|
53 |
+
def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool):
|
54 |
+
metadata = []
|
55 |
+
for book_dir in speaker_dir.glob("*"):
|
56 |
+
if no_alignments:
|
57 |
+
# Gather the utterance audios and texts
|
58 |
+
# LibriTTS uses .wav but we will include extensions for compatibility with other datasets
|
59 |
+
extensions = ["*.wav", "*.flac", "*.mp3"]
|
60 |
+
for extension in extensions:
|
61 |
+
wav_fpaths = book_dir.glob(extension)
|
62 |
+
|
63 |
+
for wav_fpath in wav_fpaths:
|
64 |
+
# Load the audio waveform
|
65 |
+
wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
|
66 |
+
if hparams.rescale:
|
67 |
+
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
68 |
+
|
69 |
+
# Get the corresponding text
|
70 |
+
# Check for .txt (for compatibility with other datasets)
|
71 |
+
text_fpath = wav_fpath.with_suffix(".txt")
|
72 |
+
if not text_fpath.exists():
|
73 |
+
# Check for .normalized.txt (LibriTTS)
|
74 |
+
text_fpath = wav_fpath.with_suffix(".normalized.txt")
|
75 |
+
assert text_fpath.exists()
|
76 |
+
with text_fpath.open("r") as text_file:
|
77 |
+
text = "".join([line for line in text_file])
|
78 |
+
text = text.replace("\"", "")
|
79 |
+
text = text.strip()
|
80 |
+
|
81 |
+
# Process the utterance
|
82 |
+
metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name),
|
83 |
+
skip_existing, hparams))
|
84 |
+
else:
|
85 |
+
# Process alignment file (LibriSpeech support)
|
86 |
+
# Gather the utterance audios and texts
|
87 |
+
try:
|
88 |
+
alignments_fpath = next(book_dir.glob("*.alignment.txt"))
|
89 |
+
with alignments_fpath.open("r") as alignments_file:
|
90 |
+
alignments = [line.rstrip().split(" ") for line in alignments_file]
|
91 |
+
except StopIteration:
|
92 |
+
# A few alignment files will be missing
|
93 |
+
continue
|
94 |
+
|
95 |
+
# Iterate over each entry in the alignments file
|
96 |
+
for wav_fname, words, end_times in alignments:
|
97 |
+
wav_fpath = book_dir.joinpath(wav_fname + ".flac")
|
98 |
+
assert wav_fpath.exists()
|
99 |
+
words = words.replace("\"", "").split(",")
|
100 |
+
end_times = list(map(float, end_times.replace("\"", "").split(",")))
|
101 |
+
|
102 |
+
# Process each sub-utterance
|
103 |
+
wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams)
|
104 |
+
for i, (wav, text) in enumerate(zip(wavs, texts)):
|
105 |
+
sub_basename = "%s_%02d" % (wav_fname, i)
|
106 |
+
metadata.append(process_utterance(wav, text, out_dir, sub_basename,
|
107 |
+
skip_existing, hparams))
|
108 |
+
|
109 |
+
return [m for m in metadata if m is not None]
|
110 |
+
|
111 |
+
|
112 |
+
def split_on_silences(wav_fpath, words, end_times, hparams):
|
113 |
+
# Load the audio waveform
|
114 |
+
wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
|
115 |
+
if hparams.rescale:
|
116 |
+
wav = wav / np.abs(wav).max() * hparams.rescaling_max
|
117 |
+
|
118 |
+
words = np.array(words)
|
119 |
+
start_times = np.array([0.0] + end_times[:-1])
|
120 |
+
end_times = np.array(end_times)
|
121 |
+
assert len(words) == len(end_times) == len(start_times)
|
122 |
+
assert words[0] == "" and words[-1] == ""
|
123 |
+
|
124 |
+
# Find pauses that are too long
|
125 |
+
mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split)
|
126 |
+
mask[0] = mask[-1] = True
|
127 |
+
breaks = np.where(mask)[0]
|
128 |
+
|
129 |
+
# Profile the noise from the silences and perform noise reduction on the waveform
|
130 |
+
silence_times = [[start_times[i], end_times[i]] for i in breaks]
|
131 |
+
silence_times = (np.array(silence_times) * hparams.sample_rate).astype(np.int)
|
132 |
+
noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times])
|
133 |
+
if len(noisy_wav) > hparams.sample_rate * 0.02:
|
134 |
+
profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate)
|
135 |
+
wav = logmmse.denoise(wav, profile, eta=0)
|
136 |
+
|
137 |
+
# Re-attach segments that are too short
|
138 |
+
segments = list(zip(breaks[:-1], breaks[1:]))
|
139 |
+
segment_durations = [start_times[end] - end_times[start] for start, end in segments]
|
140 |
+
i = 0
|
141 |
+
while i < len(segments) and len(segments) > 1:
|
142 |
+
if segment_durations[i] < hparams.utterance_min_duration:
|
143 |
+
# See if the segment can be re-attached with the right or the left segment
|
144 |
+
left_duration = float("inf") if i == 0 else segment_durations[i - 1]
|
145 |
+
right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1]
|
146 |
+
joined_duration = segment_durations[i] + min(left_duration, right_duration)
|
147 |
+
|
148 |
+
# Do not re-attach if it causes the joined utterance to be too long
|
149 |
+
if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate:
|
150 |
+
i += 1
|
151 |
+
continue
|
152 |
+
|
153 |
+
# Re-attach the segment with the neighbour of shortest duration
|
154 |
+
j = i - 1 if left_duration <= right_duration else i
|
155 |
+
segments[j] = (segments[j][0], segments[j + 1][1])
|
156 |
+
segment_durations[j] = joined_duration
|
157 |
+
del segments[j + 1], segment_durations[j + 1]
|
158 |
+
else:
|
159 |
+
i += 1
|
160 |
+
|
161 |
+
# Split the utterance
|
162 |
+
segment_times = [[end_times[start], start_times[end]] for start, end in segments]
|
163 |
+
segment_times = (np.array(segment_times) * hparams.sample_rate).astype(np.int)
|
164 |
+
wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times]
|
165 |
+
texts = [" ".join(words[start + 1:end]).replace(" ", " ") for start, end in segments]
|
166 |
+
|
167 |
+
# # DEBUG: play the audio segments (run with -n=1)
|
168 |
+
# import sounddevice as sd
|
169 |
+
# if len(wavs) > 1:
|
170 |
+
# print("This sentence was split in %d segments:" % len(wavs))
|
171 |
+
# else:
|
172 |
+
# print("There are no silences long enough for this sentence to be split:")
|
173 |
+
# for wav, text in zip(wavs, texts):
|
174 |
+
# # Pad the waveform with 1 second of silence because sounddevice tends to cut them early
|
175 |
+
# # when playing them. You shouldn't need to do that in your parsers.
|
176 |
+
# wav = np.concatenate((wav, [0] * 16000))
|
177 |
+
# print("\t%s" % text)
|
178 |
+
# sd.play(wav, 16000, blocking=True)
|
179 |
+
# print("")
|
180 |
+
|
181 |
+
return wavs, texts
|
182 |
+
|
183 |
+
|
184 |
+
def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
|
185 |
+
skip_existing: bool, hparams):
|
186 |
+
## FOR REFERENCE:
|
187 |
+
# For you not to lose your head if you ever wish to change things here or implement your own
|
188 |
+
# synthesizer.
|
189 |
+
# - Both the audios and the mel spectrograms are saved as numpy arrays
|
190 |
+
# - There is no processing done to the audios that will be saved to disk beyond volume
|
191 |
+
# normalization (in split_on_silences)
|
192 |
+
# - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
|
193 |
+
# is why we re-apply it on the audio on the side of the vocoder.
|
194 |
+
# - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
|
195 |
+
# without extra padding. This means that you won't have an exact relation between the length
|
196 |
+
# of the wav and of the mel spectrogram. See the vocoder data loader.
|
197 |
+
|
198 |
+
|
199 |
+
# Skip existing utterances if needed
|
200 |
+
mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
|
201 |
+
wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
|
202 |
+
if skip_existing and mel_fpath.exists() and wav_fpath.exists():
|
203 |
+
return None
|
204 |
+
|
205 |
+
# Trim silence
|
206 |
+
if hparams.trim_silence:
|
207 |
+
wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)
|
208 |
+
|
209 |
+
# Skip utterances that are too short
|
210 |
+
if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
|
211 |
+
return None
|
212 |
+
|
213 |
+
# Compute the mel spectrogram
|
214 |
+
mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
|
215 |
+
mel_frames = mel_spectrogram.shape[1]
|
216 |
+
|
217 |
+
# Skip utterances that are too long
|
218 |
+
if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
|
219 |
+
return None
|
220 |
+
|
221 |
+
# Write the spectrogram, embed and audio to disk
|
222 |
+
np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
|
223 |
+
np.save(wav_fpath, wav, allow_pickle=False)
|
224 |
+
|
225 |
+
# Return a tuple describing this training example
|
226 |
+
return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
|
227 |
+
|
228 |
+
|
229 |
+
def embed_utterance(fpaths, encoder_model_fpath):
|
230 |
+
if not encoder.is_loaded():
|
231 |
+
encoder.load_model(encoder_model_fpath)
|
232 |
+
|
233 |
+
# Compute the speaker embedding of the utterance
|
234 |
+
wav_fpath, embed_fpath = fpaths
|
235 |
+
wav = np.load(wav_fpath)
|
236 |
+
wav = encoder.preprocess_wav(wav)
|
237 |
+
embed = encoder.embed_utterance(wav)
|
238 |
+
np.save(embed_fpath, embed, allow_pickle=False)
|
239 |
+
|
240 |
+
|
241 |
+
def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
|
242 |
+
wav_dir = synthesizer_root.joinpath("audio")
|
243 |
+
metadata_fpath = synthesizer_root.joinpath("train.txt")
|
244 |
+
assert wav_dir.exists() and metadata_fpath.exists()
|
245 |
+
embed_dir = synthesizer_root.joinpath("embeds")
|
246 |
+
embed_dir.mkdir(exist_ok=True)
|
247 |
+
|
248 |
+
# Gather the input wave filepath and the target output embed filepath
|
249 |
+
with metadata_fpath.open("r") as metadata_file:
|
250 |
+
metadata = [line.split("|") for line in metadata_file]
|
251 |
+
fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
|
252 |
+
|
253 |
+
# TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
|
254 |
+
# Embed the utterances in separate threads
|
255 |
+
func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
|
256 |
+
job = Pool(n_processes).imap(func, fpaths)
|
257 |
+
list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
|
258 |
+
|
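For reference, each line that `preprocess_dataset` writes to train.txt is the "|"-joined tuple returned by `process_utterance` above. A small illustrative sketch follows; the filenames and numbers are made up, and the 16 kHz rate is the default synthesizer sample rate.
```
# Hypothetical train.txt line; the real files are produced by the code above
line = "audio-1320_00000_00.npy|mel-1320_00000_00.npy|embed-1320_00000_00.npy|52800|264|and so it began\n"
wav_fname, mel_fname, embed_fname, n_samples, n_frames, text = line.strip().split("|")
print("%.2f seconds of audio, %s mel frames: %s" % (int(n_samples) / 16000, n_frames, text))
```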
synthesizer/synthesize.py
ADDED
@@ -0,0 +1,92 @@
1 |
+
import platform
|
2 |
+
from functools import partial
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from synthesizer.hparams import hparams_debug_string
|
11 |
+
from synthesizer.models.tacotron import Tacotron
|
12 |
+
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
|
13 |
+
from synthesizer.utils import data_parallel_workaround
|
14 |
+
from synthesizer.utils.symbols import symbols
|
15 |
+
|
16 |
+
|
17 |
+
def run_synthesis(in_dir: Path, out_dir: Path, syn_model_fpath: Path, hparams):
|
18 |
+
# This generates ground truth-aligned mels for vocoder training
|
19 |
+
synth_dir = out_dir / "mels_gta"
|
20 |
+
synth_dir.mkdir(exist_ok=True, parents=True)
|
21 |
+
print(hparams_debug_string())
|
22 |
+
|
23 |
+
# Check for GPU
|
24 |
+
if torch.cuda.is_available():
|
25 |
+
device = torch.device("cuda")
|
26 |
+
if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
|
27 |
+
raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
|
28 |
+
else:
|
29 |
+
device = torch.device("cpu")
|
30 |
+
print("Synthesizer using device:", device)
|
31 |
+
|
32 |
+
# Instantiate Tacotron model
|
33 |
+
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
34 |
+
num_chars=len(symbols),
|
35 |
+
encoder_dims=hparams.tts_encoder_dims,
|
36 |
+
decoder_dims=hparams.tts_decoder_dims,
|
37 |
+
n_mels=hparams.num_mels,
|
38 |
+
fft_bins=hparams.num_mels,
|
39 |
+
postnet_dims=hparams.tts_postnet_dims,
|
40 |
+
encoder_K=hparams.tts_encoder_K,
|
41 |
+
lstm_dims=hparams.tts_lstm_dims,
|
42 |
+
postnet_K=hparams.tts_postnet_K,
|
43 |
+
num_highways=hparams.tts_num_highways,
|
44 |
+
dropout=0., # Use zero dropout for gta mels
|
45 |
+
stop_threshold=hparams.tts_stop_threshold,
|
46 |
+
speaker_embedding_size=hparams.speaker_embedding_size).to(device)
|
47 |
+
|
48 |
+
# Load the weights
|
49 |
+
print("\nLoading weights at %s" % syn_model_fpath)
|
50 |
+
model.load(syn_model_fpath)
|
51 |
+
print("Tacotron weights loaded from step %d" % model.step)
|
52 |
+
|
53 |
+
# Synthesize using same reduction factor as the model is currently trained
|
54 |
+
r = np.int32(model.r)
|
55 |
+
|
56 |
+
# Set model to eval mode (disable gradient and zoneout)
|
57 |
+
model.eval()
|
58 |
+
|
59 |
+
# Initialize the dataset
|
60 |
+
metadata_fpath = in_dir.joinpath("train.txt")
|
61 |
+
mel_dir = in_dir.joinpath("mels")
|
62 |
+
embed_dir = in_dir.joinpath("embeds")
|
63 |
+
|
64 |
+
dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
|
65 |
+
collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
|
66 |
+
data_loader = DataLoader(dataset, hparams.synthesis_batch_size, collate_fn=collate_fn, num_workers=2)
|
67 |
+
|
68 |
+
# Generate GTA mels
|
69 |
+
meta_out_fpath = out_dir / "synthesized.txt"
|
70 |
+
with meta_out_fpath.open("w") as file:
|
71 |
+
for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
|
72 |
+
texts, mels, embeds = texts.to(device), mels.to(device), embeds.to(device)
|
73 |
+
|
74 |
+
# Parallelize model onto GPUS using workaround due to python bug
|
75 |
+
if device.type == "cuda" and torch.cuda.device_count() > 1:
|
76 |
+
_, mels_out, _ = data_parallel_workaround(model, texts, mels, embeds)
|
77 |
+
else:
|
78 |
+
_, mels_out, _, _ = model(texts, mels, embeds)
|
79 |
+
|
80 |
+
for j, k in enumerate(idx):
|
81 |
+
# Note: outputs mel-spectrogram files and target ones have same names, just different folders
|
82 |
+
mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
|
83 |
+
mel_out = mels_out[j].detach().cpu().numpy().T
|
84 |
+
|
85 |
+
# Use the length of the ground truth mel to remove padding from the generated mels
|
86 |
+
mel_out = mel_out[:int(dataset.metadata[k][4])]
|
87 |
+
|
88 |
+
# Write the spectrogram to disk
|
89 |
+
np.save(mel_filename, mel_out, allow_pickle=False)
|
90 |
+
|
91 |
+
# Write metadata into the synthesized file
|
92 |
+
file.write("|".join(dataset.metadata[k]))
|
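A hypothetical invocation of `run_synthesis` is sketched below; the dataset paths are placeholders, while the model path matches the checkpoint shipped in this Space.
```
from pathlib import Path

from synthesizer.hparams import hparams
from synthesizer.synthesize import run_synthesis

# Writes ground-truth-aligned mels to <out_dir>/mels_gta plus a synthesized.txt metadata file
run_synthesis(in_dir=Path("datasets/SV2TTS/synthesizer"),   # placeholder
              out_dir=Path("datasets/SV2TTS/vocoder"),      # placeholder
              syn_model_fpath=Path("saved_models/default/synthesizer.pt"),
              hparams=hparams)
```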
synthesizer/synthesizer_dataset.py
ADDED
@@ -0,0 +1,92 @@
import torch
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from synthesizer.utils.text import text_to_sequence


class SynthesizerDataset(Dataset):
    def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
        print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))

        with metadata_fpath.open("r") as metadata_file:
            metadata = [line.split("|") for line in metadata_file]

        mel_fnames = [x[1] for x in metadata if int(x[4])]
        mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
        embed_fnames = [x[2] for x in metadata if int(x[4])]
        embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
        self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
        self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
        self.metadata = metadata
        self.hparams = hparams

        print("Found %d samples" % len(self.samples_fpaths))

    def __getitem__(self, index):
        # Sometimes index may be a list of 2 (not sure why this happens)
        # If that is the case, return a single item corresponding to first element in index
        if isinstance(index, list):
            index = index[0]

        mel_path, embed_path = self.samples_fpaths[index]
        mel = np.load(mel_path).T.astype(np.float32)

        # Load the embed
        embed = np.load(embed_path)

        # Get the text and clean it
        text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)

        # Convert the list returned by text_to_sequence to a numpy array
        text = np.asarray(text).astype(np.int32)

        return text, mel.astype(np.float32), embed.astype(np.float32), index

    def __len__(self):
        return len(self.samples_fpaths)


def collate_synthesizer(batch, r, hparams):
    # Text
    x_lens = [len(x[0]) for x in batch]
    max_x_len = max(x_lens)

    chars = [pad1d(x[0], max_x_len) for x in batch]
    chars = np.stack(chars)

    # Mel spectrogram
    spec_lens = [x[1].shape[-1] for x in batch]
    max_spec_len = max(spec_lens) + 1
    if max_spec_len % r != 0:
        max_spec_len += r - max_spec_len % r

    # WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence
    # By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence.
    if hparams.symmetric_mels:
        mel_pad_value = -1 * hparams.max_abs_value
    else:
        mel_pad_value = 0

    mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
    mel = np.stack(mel)

    # Speaker embedding (SV2TTS)
    embeds = np.array([x[2] for x in batch])

    # Index (for vocoder preprocessing)
    indices = [x[3] for x in batch]

    # Convert all to tensor
    chars = torch.tensor(chars).long()
    mel = torch.tensor(mel)
    embeds = torch.tensor(embeds)

    return chars, mel, embeds, indices


def pad1d(x, max_len, pad_value=0):
    return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)


def pad2d(x, max_len, pad_value=0):
    return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
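The dataset and collate function above are meant to be combined through `functools.partial`, as both synthesize.py and train.py in this commit do. A condensed sketch (the dataset path is a placeholder):
```
from functools import partial
from pathlib import Path

from torch.utils.data import DataLoader

from synthesizer.hparams import hparams
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer

syn_dir = Path("datasets/SV2TTS/synthesizer")  # placeholder
dataset = SynthesizerDataset(syn_dir / "train.txt", syn_dir / "mels", syn_dir / "embeds", hparams)

# r must match the model's current reduction factor so mel lengths pad to a multiple of it
collate_fn = partial(collate_synthesizer, r=2, hparams=hparams)
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

chars, mels, embeds, indices = next(iter(loader))
print(chars.shape, mels.shape, embeds.shape)  # (16, max_chars), (16, num_mels, padded_frames), (16, embed_size)
```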
synthesizer/train.py
ADDED
@@ -0,0 +1,258 @@
1 |
+
from datetime import datetime
|
2 |
+
from functools import partial
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import optim
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
|
10 |
+
from synthesizer import audio
|
11 |
+
from synthesizer.models.tacotron import Tacotron
|
12 |
+
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
|
13 |
+
from synthesizer.utils import ValueWindow, data_parallel_workaround
|
14 |
+
from synthesizer.utils.plot import plot_spectrogram
|
15 |
+
from synthesizer.utils.symbols import symbols
|
16 |
+
from synthesizer.utils.text import sequence_to_text
|
17 |
+
from vocoder.display import *
|
18 |
+
|
19 |
+
|
20 |
+
def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
|
21 |
+
|
22 |
+
|
23 |
+
def time_string():
|
24 |
+
return datetime.now().strftime("%Y-%m-%d %H:%M")
|
25 |
+
|
26 |
+
|
27 |
+
def train(run_id: str, syn_dir: Path, models_dir: Path, save_every: int, backup_every: int, force_restart: bool,
|
28 |
+
hparams):
|
29 |
+
models_dir.mkdir(exist_ok=True)
|
30 |
+
|
31 |
+
model_dir = models_dir.joinpath(run_id)
|
32 |
+
plot_dir = model_dir.joinpath("plots")
|
33 |
+
wav_dir = model_dir.joinpath("wavs")
|
34 |
+
mel_output_dir = model_dir.joinpath("mel-spectrograms")
|
35 |
+
meta_folder = model_dir.joinpath("metas")
|
36 |
+
model_dir.mkdir(exist_ok=True)
|
37 |
+
plot_dir.mkdir(exist_ok=True)
|
38 |
+
wav_dir.mkdir(exist_ok=True)
|
39 |
+
mel_output_dir.mkdir(exist_ok=True)
|
40 |
+
meta_folder.mkdir(exist_ok=True)
|
41 |
+
|
42 |
+
weights_fpath = model_dir / f"synthesizer.pt"
|
43 |
+
metadata_fpath = syn_dir.joinpath("train.txt")
|
44 |
+
|
45 |
+
print("Checkpoint path: {}".format(weights_fpath))
|
46 |
+
print("Loading training data from: {}".format(metadata_fpath))
|
47 |
+
print("Using model: Tacotron")
|
48 |
+
|
49 |
+
# Bookkeeping
|
50 |
+
time_window = ValueWindow(100)
|
51 |
+
loss_window = ValueWindow(100)
|
52 |
+
|
53 |
+
# From WaveRNN/train_tacotron.py
|
54 |
+
if torch.cuda.is_available():
|
55 |
+
device = torch.device("cuda")
|
56 |
+
|
57 |
+
for session in hparams.tts_schedule:
|
58 |
+
_, _, _, batch_size = session
|
59 |
+
if batch_size % torch.cuda.device_count() != 0:
|
60 |
+
raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
|
61 |
+
else:
|
62 |
+
device = torch.device("cpu")
|
63 |
+
print("Using device:", device)
|
64 |
+
|
65 |
+
# Instantiate Tacotron Model
|
66 |
+
print("\nInitialising Tacotron Model...\n")
|
67 |
+
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
68 |
+
num_chars=len(symbols),
|
69 |
+
encoder_dims=hparams.tts_encoder_dims,
|
70 |
+
decoder_dims=hparams.tts_decoder_dims,
|
71 |
+
n_mels=hparams.num_mels,
|
72 |
+
fft_bins=hparams.num_mels,
|
73 |
+
postnet_dims=hparams.tts_postnet_dims,
|
74 |
+
encoder_K=hparams.tts_encoder_K,
|
75 |
+
lstm_dims=hparams.tts_lstm_dims,
|
76 |
+
postnet_K=hparams.tts_postnet_K,
|
77 |
+
num_highways=hparams.tts_num_highways,
|
78 |
+
dropout=hparams.tts_dropout,
|
79 |
+
stop_threshold=hparams.tts_stop_threshold,
|
80 |
+
speaker_embedding_size=hparams.speaker_embedding_size).to(device)
|
81 |
+
|
82 |
+
# Initialize the optimizer
|
83 |
+
optimizer = optim.Adam(model.parameters())
|
84 |
+
|
85 |
+
# Load the weights
|
86 |
+
if force_restart or not weights_fpath.exists():
|
87 |
+
print("\nStarting the training of Tacotron from scratch\n")
|
88 |
+
model.save(weights_fpath)
|
89 |
+
|
90 |
+
# Embeddings metadata
|
91 |
+
char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
|
92 |
+
with open(char_embedding_fpath, "w", encoding="utf-8") as f:
|
93 |
+
for symbol in symbols:
|
94 |
+
if symbol == " ":
|
95 |
+
symbol = "\\s" # For visual purposes, swap space with \s
|
96 |
+
|
97 |
+
f.write("{}\n".format(symbol))
|
98 |
+
|
99 |
+
else:
|
100 |
+
print("\nLoading weights at %s" % weights_fpath)
|
101 |
+
model.load(weights_fpath, optimizer)
|
102 |
+
print("Tacotron weights loaded from step %d" % model.step)
|
103 |
+
|
104 |
+
# Initialize the dataset
|
105 |
+
metadata_fpath = syn_dir.joinpath("train.txt")
|
106 |
+
mel_dir = syn_dir.joinpath("mels")
|
107 |
+
embed_dir = syn_dir.joinpath("embeds")
|
108 |
+
dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
|
109 |
+
|
110 |
+
for i, session in enumerate(hparams.tts_schedule):
|
111 |
+
current_step = model.get_step()
|
112 |
+
|
113 |
+
r, lr, max_step, batch_size = session
|
114 |
+
|
115 |
+
training_steps = max_step - current_step
|
116 |
+
|
117 |
+
# Do we need to change to the next session?
|
118 |
+
if current_step >= max_step:
|
119 |
+
# Are there no further sessions than the current one?
|
120 |
+
if i == len(hparams.tts_schedule) - 1:
|
121 |
+
# We have completed training. Save the model and exit
|
122 |
+
model.save(weights_fpath, optimizer)
|
123 |
+
break
|
124 |
+
else:
|
125 |
+
# There is a following session, go to it
|
126 |
+
continue
|
127 |
+
|
128 |
+
model.r = r
|
129 |
+
|
130 |
+
# Begin the training
|
131 |
+
simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
|
132 |
+
("Batch Size", batch_size),
|
133 |
+
("Learning Rate", lr),
|
134 |
+
("Outputs/Step (r)", model.r)])
|
135 |
+
|
136 |
+
for p in optimizer.param_groups:
|
137 |
+
p["lr"] = lr
|
138 |
+
|
139 |
+
collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
|
140 |
+
data_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
|
141 |
+
|
142 |
+
total_iters = len(dataset)
|
143 |
+
steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
|
144 |
+
epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)
|
145 |
+
|
146 |
+
for epoch in range(1, epochs+1):
|
147 |
+
for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
|
148 |
+
start_time = time.time()
|
149 |
+
|
150 |
+
# Generate stop tokens for training
|
151 |
+
stop = torch.ones(mels.shape[0], mels.shape[2])
|
152 |
+
for j, k in enumerate(idx):
|
153 |
+
stop[j, :int(dataset.metadata[k][4])-1] = 0
|
154 |
+
|
155 |
+
texts = texts.to(device)
|
156 |
+
mels = mels.to(device)
|
157 |
+
embeds = embeds.to(device)
|
158 |
+
stop = stop.to(device)
|
159 |
+
|
160 |
+
# Forward pass
|
161 |
+
# Parallelize model onto GPUS using workaround due to python bug
|
162 |
+
if device.type == "cuda" and torch.cuda.device_count() > 1:
|
163 |
+
m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts, mels, embeds)
|
164 |
+
else:
|
165 |
+
m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)
|
166 |
+
|
167 |
+
# Backward pass
|
168 |
+
m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
|
169 |
+
m2_loss = F.mse_loss(m2_hat, mels)
|
170 |
+
stop_loss = F.binary_cross_entropy(stop_pred, stop)
|
171 |
+
|
172 |
+
loss = m1_loss + m2_loss + stop_loss
|
173 |
+
|
174 |
+
optimizer.zero_grad()
|
175 |
+
loss.backward()
|
176 |
+
|
177 |
+
if hparams.tts_clip_grad_norm is not None:
|
178 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
|
179 |
+
if np.isnan(grad_norm.cpu()):
|
180 |
+
print("grad_norm was NaN!")
|
181 |
+
|
182 |
+
optimizer.step()
|
183 |
+
|
184 |
+
time_window.append(time.time() - start_time)
|
185 |
+
loss_window.append(loss.item())
|
186 |
+
|
187 |
+
step = model.get_step()
|
188 |
+
k = step // 1000
|
189 |
+
|
190 |
+
msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | " \
|
191 |
+
f"{1./time_window.average:#.2} steps/s | Step: {k}k | "
|
192 |
+
stream(msg)
|
193 |
+
|
194 |
+
# Backup or save model as appropriate
|
195 |
+
if backup_every != 0 and step % backup_every == 0 :
|
196 |
+
backup_fpath = weights_fpath.parent / f"synthesizer_{k:06d}.pt"
|
197 |
+
model.save(backup_fpath, optimizer)
|
198 |
+
|
199 |
+
if save_every != 0 and step % save_every == 0 :
|
200 |
+
# Must save latest optimizer state to ensure that resuming training
|
201 |
+
# doesn't produce artifacts
|
202 |
+
model.save(weights_fpath, optimizer)
|
203 |
+
|
204 |
+
# Evaluate model to generate samples
|
205 |
+
epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
|
206 |
+
step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps
|
207 |
+
if epoch_eval or step_eval:
|
208 |
+
for sample_idx in range(hparams.tts_eval_num_samples):
|
209 |
+
# At most, generate samples equal to number in the batch
|
210 |
+
if sample_idx + 1 <= len(texts):
|
211 |
+
# Remove padding from mels using frame length in metadata
|
212 |
+
mel_length = int(dataset.metadata[idx[sample_idx]][4])
|
213 |
+
mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
|
214 |
+
target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
|
215 |
+
attention_len = mel_length // model.r
|
216 |
+
|
217 |
+
eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
|
218 |
+
mel_prediction=mel_prediction,
|
219 |
+
target_spectrogram=target_spectrogram,
|
220 |
+
input_seq=np_now(texts[sample_idx]),
|
221 |
+
step=step,
|
222 |
+
plot_dir=plot_dir,
|
223 |
+
mel_output_dir=mel_output_dir,
|
224 |
+
wav_dir=wav_dir,
|
225 |
+
sample_num=sample_idx + 1,
|
226 |
+
loss=loss,
|
227 |
+
hparams=hparams)
|
228 |
+
|
229 |
+
# Break out of loop to update training schedule
|
230 |
+
if step >= max_step:
|
231 |
+
break
|
232 |
+
|
233 |
+
# Add line break after every epoch
|
234 |
+
print("")
|
235 |
+
|
236 |
+
|
237 |
+
def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
|
238 |
+
plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
|
239 |
+
# Save some results for evaluation
|
240 |
+
attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
|
241 |
+
save_attention(attention, attention_path)
|
242 |
+
|
243 |
+
# save predicted mel spectrogram to disk (debug)
|
244 |
+
mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
|
245 |
+
np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)
|
246 |
+
|
247 |
+
# save griffin lim inverted wav for debug (mel -> wav)
|
248 |
+
wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
|
249 |
+
wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
|
250 |
+
audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)
|
251 |
+
|
252 |
+
# save real and predicted mel-spectrogram plot to disk (control purposes)
|
253 |
+
spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
|
254 |
+
title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
|
255 |
+
plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
|
256 |
+
target_spectrogram=target_spectrogram,
|
257 |
+
max_len=target_spectrogram.size // hparams.num_mels)
|
258 |
+
print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
|
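The stop targets built in the inner training loop above mark an utterance's final frame and all of its padding frames as 1, and everything before that as 0, so the model learns when to emit a stop token. A small worked example with made-up lengths:
```
import torch

# Illustrative only: mirrors the stop-token targets built in train() above
mel_lens = [4, 6]          # hypothetical true frame counts (column 4 of train.txt)
stop = torch.ones(2, 6)    # (batch, max_frames) after collate padding
for j, mel_len in enumerate(mel_lens):
    stop[j, :mel_len - 1] = 0
print(stop)
# tensor([[0., 0., 0., 1., 1., 1.],
#         [0., 0., 0., 0., 0., 1.]])
```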
synthesizer/utils/__init__.py
ADDED
@@ -0,0 +1,45 @@
import torch


_output_ref = None
_replicas_ref = None

def data_parallel_workaround(model, *input):
    global _output_ref
    global _replicas_ref
    device_ids = list(range(torch.cuda.device_count()))
    output_device = device_ids[0]
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # input.shape = (num_args, batch, ...)
    inputs = torch.nn.parallel.scatter(input, device_ids)
    # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
    replicas = replicas[:len(inputs)]
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat


class ValueWindow():
    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        self._values = self._values[-(self._window_size - 1):] + [x]

    @property
    def sum(self):
        return sum(self._values)

    @property
    def count(self):
        return len(self._values)

    @property
    def average(self):
        return self.sum / max(1, self.count)

    def reset(self):
        self._values = []
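`ValueWindow` is just a fixed-size running average used for the loss and timing readout in train.py, for example:
```
from synthesizer.utils import ValueWindow

loss_window = ValueWindow(window_size=3)
for loss in [4.0, 3.0, 2.0, 1.0]:
    loss_window.append(loss)
print(loss_window.average)  # 2.0 -- only the last 3 values (3.0, 2.0, 1.0) are kept
```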
synthesizer/utils/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (1.68 kB). View file
|
|
synthesizer/utils/__pycache__/cleaners.cpython-37.pyc
ADDED
Binary file (2.81 kB). View file
|
|
synthesizer/utils/__pycache__/numbers.cpython-37.pyc
ADDED
Binary file (2.18 kB). View file
|
|
synthesizer/utils/__pycache__/symbols.cpython-37.pyc
ADDED
Binary file (582 Bytes). View file
|
|
synthesizer/utils/__pycache__/text.cpython-37.pyc
ADDED
Binary file (2.71 kB). View file
|
|
synthesizer/utils/_cmudict.py
ADDED
@@ -0,0 +1,62 @@
1 |
+
import re
|
2 |
+
|
3 |
+
valid_symbols = [
|
4 |
+
"AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
|
5 |
+
"AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
|
6 |
+
"B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
|
7 |
+
"EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
|
8 |
+
"IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
|
9 |
+
"OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
|
10 |
+
"UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
|
11 |
+
]
|
12 |
+
|
13 |
+
_valid_symbol_set = set(valid_symbols)
|
14 |
+
|
15 |
+
|
16 |
+
class CMUDict:
|
17 |
+
"""Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
|
18 |
+
def __init__(self, file_or_path, keep_ambiguous=True):
|
19 |
+
if isinstance(file_or_path, str):
|
20 |
+
with open(file_or_path, encoding="latin-1") as f:
|
21 |
+
entries = _parse_cmudict(f)
|
22 |
+
else:
|
23 |
+
entries = _parse_cmudict(file_or_path)
|
24 |
+
if not keep_ambiguous:
|
25 |
+
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
|
26 |
+
self._entries = entries
|
27 |
+
|
28 |
+
|
29 |
+
def __len__(self):
|
30 |
+
return len(self._entries)
|
31 |
+
|
32 |
+
|
33 |
+
def lookup(self, word):
|
34 |
+
"""Returns list of ARPAbet pronunciations of the given word."""
|
35 |
+
return self._entries.get(word.upper())
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
_alt_re = re.compile(r"\([0-9]+\)")
|
40 |
+
|
41 |
+
|
42 |
+
def _parse_cmudict(file):
|
43 |
+
cmudict = {}
|
44 |
+
for line in file:
|
45 |
+
if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
|
46 |
+
parts = line.split(" ")
|
47 |
+
word = re.sub(_alt_re, "", parts[0])
|
48 |
+
pronunciation = _get_pronunciation(parts[1])
|
49 |
+
if pronunciation:
|
50 |
+
if word in cmudict:
|
51 |
+
cmudict[word].append(pronunciation)
|
52 |
+
else:
|
53 |
+
cmudict[word] = [pronunciation]
|
54 |
+
return cmudict
|
55 |
+
|
56 |
+
|
57 |
+
def _get_pronunciation(s):
|
58 |
+
parts = s.strip().split(" ")
|
59 |
+
for part in parts:
|
60 |
+
if part not in _valid_symbol_set:
|
61 |
+
return None
|
62 |
+
return " ".join(parts)
|
synthesizer/utils/cleaners.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
"""
|
2 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
3 |
+
|
4 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
5 |
+
hyperparameter. Some cleaners are English-specific. You"ll typically want to use:
|
6 |
+
1. "english_cleaners" for English text
|
7 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
8 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
9 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
10 |
+
the symbols in symbols.py to match your data).
|
11 |
+
"""
|
12 |
+
import re
|
13 |
+
from unidecode import unidecode
|
14 |
+
from synthesizer.utils.numbers import normalize_numbers
|
15 |
+
|
16 |
+
|
17 |
+
# Regular expression matching whitespace:
|
18 |
+
_whitespace_re = re.compile(r"\s+")
|
19 |
+
|
20 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
21 |
+
_abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
|
22 |
+
("mrs", "misess"),
|
23 |
+
("mr", "mister"),
|
24 |
+
("dr", "doctor"),
|
25 |
+
("st", "saint"),
|
26 |
+
("co", "company"),
|
27 |
+
("jr", "junior"),
|
28 |
+
("maj", "major"),
|
29 |
+
("gen", "general"),
|
30 |
+
("drs", "doctors"),
|
31 |
+
("rev", "reverend"),
|
32 |
+
("lt", "lieutenant"),
|
33 |
+
("hon", "honorable"),
|
34 |
+
("sgt", "sergeant"),
|
35 |
+
("capt", "captain"),
|
36 |
+
("esq", "esquire"),
|
37 |
+
("ltd", "limited"),
|
38 |
+
("col", "colonel"),
|
39 |
+
("ft", "fort"),
|
40 |
+
]]
|
41 |
+
|
42 |
+
|
43 |
+
def expand_abbreviations(text):
|
44 |
+
for regex, replacement in _abbreviations:
|
45 |
+
text = re.sub(regex, replacement, text)
|
46 |
+
return text
|
47 |
+
|
48 |
+
|
49 |
+
def expand_numbers(text):
|
50 |
+
return normalize_numbers(text)
|
51 |
+
|
52 |
+
|
53 |
+
def lowercase(text):
|
54 |
+
"""lowercase input tokens."""
|
55 |
+
return text.lower()
|
56 |
+
|
57 |
+
|
58 |
+
def collapse_whitespace(text):
|
59 |
+
return re.sub(_whitespace_re, " ", text)
|
60 |
+
|
61 |
+
|
62 |
+
def convert_to_ascii(text):
|
63 |
+
return unidecode(text)
|
64 |
+
|
65 |
+
|
66 |
+
def basic_cleaners(text):
|
67 |
+
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
68 |
+
text = lowercase(text)
|
69 |
+
text = collapse_whitespace(text)
|
70 |
+
return text
|
71 |
+
|
72 |
+
|
73 |
+
def transliteration_cleaners(text):
|
74 |
+
"""Pipeline for non-English text that transliterates to ASCII."""
|
75 |
+
text = convert_to_ascii(text)
|
76 |
+
text = lowercase(text)
|
77 |
+
text = collapse_whitespace(text)
|
78 |
+
return text
|
79 |
+
|
80 |
+
|
81 |
+
def english_cleaners(text):
|
82 |
+
"""Pipeline for English text, including number and abbreviation expansion."""
|
83 |
+
text = convert_to_ascii(text)
|
84 |
+
text = lowercase(text)
|
85 |
+
text = expand_numbers(text)
|
86 |
+
text = expand_abbreviations(text)
|
87 |
+
text = collapse_whitespace(text)
|
88 |
+
return text
|
synthesizer/utils/numbers.py
ADDED
@@ -0,0 +1,69 @@
1 |
+
import re
|
2 |
+
import inflect
|
3 |
+
|
4 |
+
|
5 |
+
_inflect = inflect.engine()
|
6 |
+
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
7 |
+
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
8 |
+
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
9 |
+
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
10 |
+
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
11 |
+
_number_re = re.compile(r"[0-9]+")
|
12 |
+
|
13 |
+
|
14 |
+
def _remove_commas(m):
|
15 |
+
return m.group(1).replace(",", "")
|
16 |
+
|
17 |
+
|
18 |
+
def _expand_decimal_point(m):
|
19 |
+
return m.group(1).replace(".", " point ")
|
20 |
+
|
21 |
+
|
22 |
+
def _expand_dollars(m):
|
23 |
+
match = m.group(1)
|
24 |
+
parts = match.split(".")
|
25 |
+
if len(parts) > 2:
|
26 |
+
return match + " dollars" # Unexpected format
|
27 |
+
dollars = int(parts[0]) if parts[0] else 0
|
28 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
29 |
+
if dollars and cents:
|
30 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
31 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
32 |
+
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
33 |
+
elif dollars:
|
34 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
35 |
+
return "%s %s" % (dollars, dollar_unit)
|
36 |
+
elif cents:
|
37 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
38 |
+
return "%s %s" % (cents, cent_unit)
|
39 |
+
else:
|
40 |
+
return "zero dollars"
|
41 |
+
|
42 |
+
|
43 |
+
def _expand_ordinal(m):
|
44 |
+
return _inflect.number_to_words(m.group(0))
|
45 |
+
|
46 |
+
|
47 |
+
def _expand_number(m):
|
48 |
+
num = int(m.group(0))
|
49 |
+
if num > 1000 and num < 3000:
|
50 |
+
if num == 2000:
|
51 |
+
return "two thousand"
|
52 |
+
elif num > 2000 and num < 2010:
|
53 |
+
return "two thousand " + _inflect.number_to_words(num % 100)
|
54 |
+
elif num % 100 == 0:
|
55 |
+
return _inflect.number_to_words(num // 100) + " hundred"
|
56 |
+
else:
|
57 |
+
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
|
58 |
+
else:
|
59 |
+
return _inflect.number_to_words(num, andword="")
|
60 |
+
|
61 |
+
|
62 |
+
def normalize_numbers(text):
|
63 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
64 |
+
text = re.sub(_pounds_re, r"\1 pounds", text)
|
65 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
66 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
67 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
68 |
+
text = re.sub(_number_re, _expand_number, text)
|
69 |
+
return text
|
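`normalize_numbers` is what turns digits, currency amounts and ordinals into words before symbol lookup. A quick check (requires the inflect package):
```
from synthesizer.utils.numbers import normalize_numbers

print(normalize_numbers("The 2nd ticket cost $15"))
# -> "The second ticket cost fifteen dollars" (exact wording comes from inflect)
```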
synthesizer/utils/plot.py
ADDED
@@ -0,0 +1,82 @@
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def split_title_line(title_text, max_words=5):
|
5 |
+
"""
|
6 |
+
A function that splits any string based on specific character
|
7 |
+
(returning it with the string), with maximum number of words on it
|
8 |
+
"""
|
9 |
+
seq = title_text.split()
|
10 |
+
return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])
|
11 |
+
|
12 |
+
|
13 |
+
def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
|
14 |
+
import matplotlib
|
15 |
+
matplotlib.use("Agg")
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
|
18 |
+
if max_len is not None:
|
19 |
+
alignment = alignment[:, :max_len]
|
20 |
+
|
21 |
+
fig = plt.figure(figsize=(8, 6))
|
22 |
+
ax = fig.add_subplot(111)
|
23 |
+
|
24 |
+
im = ax.imshow(
|
25 |
+
alignment,
|
26 |
+
aspect="auto",
|
27 |
+
origin="lower",
|
28 |
+
interpolation="none")
|
29 |
+
fig.colorbar(im, ax=ax)
|
30 |
+
xlabel = "Decoder timestep"
|
31 |
+
|
32 |
+
if split_title:
|
33 |
+
title = split_title_line(title)
|
34 |
+
|
35 |
+
plt.xlabel(xlabel)
|
36 |
+
plt.title(title)
|
37 |
+
plt.ylabel("Encoder timestep")
|
38 |
+
plt.tight_layout()
|
39 |
+
plt.savefig(path, format="png")
|
40 |
+
plt.close()
|
41 |
+
|
42 |
+
|
43 |
+
def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
|
44 |
+
import matplotlib
|
45 |
+
matplotlib.use("Agg")
|
46 |
+
import matplotlib.pyplot as plt
|
47 |
+
|
48 |
+
if max_len is not None:
|
49 |
+
target_spectrogram = target_spectrogram[:max_len]
|
50 |
+
pred_spectrogram = pred_spectrogram[:max_len]
|
51 |
+
|
52 |
+
if split_title:
|
53 |
+
title = split_title_line(title)
|
54 |
+
|
55 |
+
fig = plt.figure(figsize=(10, 8))
|
56 |
+
# Set common labels
|
57 |
+
fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
|
58 |
+
|
59 |
+
#target spectrogram subplot
|
60 |
+
if target_spectrogram is not None:
|
61 |
+
ax1 = fig.add_subplot(311)
|
62 |
+
ax2 = fig.add_subplot(312)
|
63 |
+
|
64 |
+
if auto_aspect:
|
65 |
+
im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
|
66 |
+
else:
|
67 |
+
im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
|
68 |
+
ax1.set_title("Target Mel-Spectrogram")
|
69 |
+
fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
|
70 |
+
ax2.set_title("Predicted Mel-Spectrogram")
|
71 |
+
else:
|
72 |
+
ax2 = fig.add_subplot(211)
|
73 |
+
|
74 |
+
if auto_aspect:
|
75 |
+
im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
|
76 |
+
else:
|
77 |
+
im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
|
78 |
+
fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
|
79 |
+
|
80 |
+
plt.tight_layout()
|
81 |
+
plt.savefig(path, format="png")
|
82 |
+
plt.close()
|
synthesizer/utils/symbols.py
ADDED
@@ -0,0 +1,21 @@
"""
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
"""
# from . import cmudict

_pad = "_"
_eos = "~"

# for zh
# _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890!\'(),-.:;? "

_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? "

# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
# _arpabet = ["@" + s for s in cmudict.valid_symbols]

# Export all symbols:
symbols = [_pad, _eos] + list(_characters)  # + _arpabet
synthesizer/utils/text.py
ADDED
@@ -0,0 +1,75 @@
from synthesizer.utils.symbols import symbols
from synthesizer.utils import cleaners
import re


# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


def text_to_sequence(text, cleaner_names):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through

    Returns:
        List of integers corresponding to the symbols in the text
    """
    sequence = []

    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # Append EOS token
    sequence.append(_symbol_to_id["~"])
    return sequence


def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    result = ""
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == "@":
                s = "{%s}" % s[1:]
            result += s
    return result.replace("}{", " ")


def _clean_text(text, cleaner_names):
    for name in cleaner_names:
        cleaner = getattr(cleaners, name)
        if not cleaner:
            raise Exception("Unknown cleaner: %s" % name)
        text = cleaner(text)
    return text


def _symbols_to_sequence(symbols):
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
    return _symbols_to_sequence(["@" + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s not in ("_", "~")
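A minimal usage sketch for the two public functions above (the cleaners pull in unidecode and inflect). Note that with the default symbol set in symbols.py the ARPAbet branch is inert, since no "@"-prefixed symbols are registered:
```
from synthesizer.utils.text import text_to_sequence, sequence_to_text

seq = text_to_sequence("Dr. Smith paid $5.", ["english_cleaners"])
print(seq[-1])                # 1: the id of the EOS symbol "~"
print(sequence_to_text(seq))  # the cleaned, lowercased text with "~" appended
```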
toolbox/__init__.py
ADDED
@@ -0,0 +1,347 @@
1 |
+
import sys
|
2 |
+
import traceback
|
3 |
+
from pathlib import Path
|
4 |
+
from time import perf_counter as timer
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from encoder import inference as encoder
|
10 |
+
from synthesizer.inference import Synthesizer
|
11 |
+
from toolbox.ui import UI
|
12 |
+
from toolbox.utterance import Utterance
|
13 |
+
from vocoder import inference as vocoder
|
14 |
+
|
15 |
+
|
16 |
+
# Use this directory structure for your datasets, or modify it to fit your needs
|
17 |
+
recognized_datasets = [
|
18 |
+
"LibriSpeech/dev-clean",
|
19 |
+
"LibriSpeech/dev-other",
|
20 |
+
"LibriSpeech/test-clean",
|
21 |
+
"LibriSpeech/test-other",
|
22 |
+
"LibriSpeech/train-clean-100",
|
23 |
+
"LibriSpeech/train-clean-360",
|
24 |
+
"LibriSpeech/train-other-500",
|
25 |
+
"LibriTTS/dev-clean",
|
26 |
+
"LibriTTS/dev-other",
|
27 |
+
"LibriTTS/test-clean",
|
28 |
+
"LibriTTS/test-other",
|
29 |
+
"LibriTTS/train-clean-100",
|
30 |
+
"LibriTTS/train-clean-360",
|
31 |
+
"LibriTTS/train-other-500",
|
32 |
+
"LJSpeech-1.1",
|
33 |
+
"VoxCeleb1/wav",
|
34 |
+
"VoxCeleb1/test_wav",
|
35 |
+
"VoxCeleb2/dev/aac",
|
36 |
+
"VoxCeleb2/test/aac",
|
37 |
+
"VCTK-Corpus/wav48",
|
38 |
+
]
|
39 |
+
|
40 |
+
# Maximum of generated wavs to keep on memory
|
41 |
+
MAX_WAVS = 15
|
42 |
+
|
43 |
+
|
44 |
+
class Toolbox:
|
45 |
+
def __init__(self, datasets_root: Path, models_dir: Path, seed: int=None):
|
46 |
+
sys.excepthook = self.excepthook
|
47 |
+
self.datasets_root = datasets_root
|
48 |
+
self.utterances = set()
|
49 |
+
self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
|
50 |
+
|
51 |
+
self.synthesizer = None # type: Synthesizer
|
52 |
+
self.current_wav = None
|
53 |
+
self.waves_list = []
|
54 |
+
self.waves_count = 0
|
55 |
+
self.waves_namelist = []
|
56 |
+
|
57 |
+
# Check for webrtcvad (enables removal of silences in vocoder output)
|
58 |
+
try:
|
59 |
+
import webrtcvad
|
60 |
+
self.trim_silences = True
|
61 |
+
except:
|
62 |
+
self.trim_silences = False
|
63 |
+
|
64 |
+
# Initialize the events and the interface
|
65 |
+
self.ui = UI()
|
66 |
+
self.reset_ui(models_dir, seed)
|
67 |
+
self.setup_events()
|
68 |
+
self.ui.start()
|
69 |
+
|
70 |
+
def excepthook(self, exc_type, exc_value, exc_tb):
|
71 |
+
traceback.print_exception(exc_type, exc_value, exc_tb)
|
72 |
+
self.ui.log("Exception: %s" % exc_value)
|
73 |
+
|
74 |
+
def setup_events(self):
|
75 |
+
# Dataset, speaker and utterance selection
|
76 |
+
self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser())
|
77 |
+
random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root,
|
78 |
+
recognized_datasets,
|
79 |
+
level)
|
80 |
+
self.ui.random_dataset_button.clicked.connect(random_func(0))
|
81 |
+
self.ui.random_speaker_button.clicked.connect(random_func(1))
|
82 |
+
self.ui.random_utterance_button.clicked.connect(random_func(2))
|
83 |
+
self.ui.dataset_box.currentIndexChanged.connect(random_func(1))
|
84 |
+
self.ui.speaker_box.currentIndexChanged.connect(random_func(2))
|
85 |
+
|
86 |
+
# Model selection
|
87 |
+
self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
|
88 |
+
def func():
|
89 |
+
self.synthesizer = None
|
90 |
+
self.ui.synthesizer_box.currentIndexChanged.connect(func)
|
91 |
+
self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
|
92 |
+
|
93 |
+
# Utterance selection
|
94 |
+
func = lambda: self.load_from_browser(self.ui.browse_file())
|
95 |
+
self.ui.browser_browse_button.clicked.connect(func)
|
96 |
+
func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current")
|
97 |
+
self.ui.utterance_history.currentIndexChanged.connect(func)
|
98 |
+
func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate)
|
99 |
+
self.ui.play_button.clicked.connect(func)
|
100 |
+
self.ui.stop_button.clicked.connect(self.ui.stop)
|
101 |
+
self.ui.record_button.clicked.connect(self.record)
|
102 |
+
|
103 |
+
#Audio
|
104 |
+
self.ui.setup_audio_devices(Synthesizer.sample_rate)
|
105 |
+
|
106 |
+
#Wav playback & save
|
107 |
+
func = lambda: self.replay_last_wav()
|
108 |
+
self.ui.replay_wav_button.clicked.connect(func)
|
109 |
+
func = lambda: self.export_current_wave()
|
110 |
+
self.ui.export_wav_button.clicked.connect(func)
|
111 |
+
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
|
112 |
+
|
113 |
+
# Generation
|
114 |
+
func = lambda: self.synthesize() or self.vocode()
|
115 |
+
self.ui.generate_button.clicked.connect(func)
|
116 |
+
self.ui.synthesize_button.clicked.connect(self.synthesize)
|
117 |
+
self.ui.vocode_button.clicked.connect(self.vocode)
|
118 |
+
self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
|
119 |
+
|
120 |
+
# UMAP legend
|
121 |
+
self.ui.clear_button.clicked.connect(self.clear_utterances)
|
122 |
+
|
123 |
+
def set_current_wav(self, index):
|
124 |
+
self.current_wav = self.waves_list[index]
|
125 |
+
|
126 |
+
def export_current_wave(self):
|
127 |
+
self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate)
|
128 |
+
|
129 |
+
def replay_last_wav(self):
|
130 |
+
self.ui.play(self.current_wav, Synthesizer.sample_rate)
|
131 |
+
|
132 |
+
def reset_ui(self, models_dir: Path, seed: int=None):
|
133 |
+
self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
|
134 |
+
self.ui.populate_models(models_dir)
|
135 |
+
self.ui.populate_gen_options(seed, self.trim_silences)
|
136 |
+
|
137 |
+
def load_from_browser(self, fpath=None):
|
138 |
+
if fpath is None:
|
139 |
+
fpath = Path(self.datasets_root,
|
140 |
+
self.ui.current_dataset_name,
|
141 |
+
self.ui.current_speaker_name,
|
142 |
+
self.ui.current_utterance_name)
|
143 |
+
name = str(fpath.relative_to(self.datasets_root))
|
144 |
+
speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name
|
145 |
+
|
146 |
+
# Select the next utterance
|
147 |
+
if self.ui.auto_next_checkbox.isChecked():
|
148 |
+
self.ui.browser_select_next()
|
149 |
+
elif fpath == "":
|
150 |
+
return
|
151 |
+
else:
|
152 |
+
name = fpath.name
|
153 |
+
speaker_name = fpath.parent.name
|
154 |
+
|
155 |
+
# Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
|
156 |
+
# playback, so as to have a fair comparison with the generated audio
|
157 |
+
wav = Synthesizer.load_preprocess_wav(fpath)
|
158 |
+
self.ui.log("Loaded %s" % name)
|
159 |
+
|
160 |
+
self.add_real_utterance(wav, name, speaker_name)
|
161 |
+
|
162 |
+
def record(self):
|
163 |
+
wav = self.ui.record_one(encoder.sampling_rate, 5)
|
164 |
+
if wav is None:
|
165 |
+
return
|
166 |
+
self.ui.play(wav, encoder.sampling_rate)
|
167 |
+
|
168 |
+
speaker_name = "user01"
|
169 |
+
name = speaker_name + "_rec_%05d" % np.random.randint(100000)
|
170 |
+
self.add_real_utterance(wav, name, speaker_name)
|
171 |
+
|
172 |
+
def add_real_utterance(self, wav, name, speaker_name):
|
173 |
+
# Compute the mel spectrogram
|
174 |
+
spec = Synthesizer.make_spectrogram(wav)
|
175 |
+
self.ui.draw_spec(spec, "current")
|
176 |
+
|
177 |
+
# Compute the embedding
|
178 |
+
if not encoder.is_loaded():
|
179 |
+
self.init_encoder()
|
180 |
+
encoder_wav = encoder.preprocess_wav(wav)
|
181 |
+
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
182 |
+
|
183 |
+
# Add the utterance
|
184 |
+
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
|
185 |
+
self.utterances.add(utterance)
|
186 |
+
self.ui.register_utterance(utterance)
|
187 |
+
|
188 |
+
# Plot it
|
189 |
+
self.ui.draw_embed(embed, name, "current")
|
190 |
+
self.ui.draw_umap_projections(self.utterances)
|
191 |
+
|
192 |
+
def clear_utterances(self):
|
193 |
+
self.utterances.clear()
|
194 |
+
self.ui.draw_umap_projections(self.utterances)
|
195 |
+
|
196 |
+
def synthesize(self):
|
197 |
+
self.ui.log("Generating the mel spectrogram...")
|
198 |
+
self.ui.set_loading(1)
|
199 |
+
|
200 |
+
# Update the synthesizer random seed
|
201 |
+
if self.ui.random_seed_checkbox.isChecked():
|
202 |
+
seed = int(self.ui.seed_textbox.text())
|
203 |
+
self.ui.populate_gen_options(seed, self.trim_silences)
|
204 |
+
else:
|
205 |
+
seed = None
|
206 |
+
|
207 |
+
if seed is not None:
|
208 |
+
torch.manual_seed(seed)
|
209 |
+
|
210 |
+
# Synthesize the spectrogram
|
211 |
+
if self.synthesizer is None or seed is not None:
|
212 |
+
self.init_synthesizer()
|
213 |
+
|
214 |
+
texts = self.ui.text_prompt.toPlainText().split("\n")
|
215 |
+
embed = self.ui.selected_utterance.embed
|
216 |
+
embeds = [embed] * len(texts)
|
217 |
+
specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
|
218 |
+
breaks = [spec.shape[1] for spec in specs]
|
219 |
+
spec = np.concatenate(specs, axis=1)
|
220 |
+
|
221 |
+
self.ui.draw_spec(spec, "generated")
|
222 |
+
self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None)
|
223 |
+
self.ui.set_loading(0)
|
224 |
+
|
225 |
+
def vocode(self):
|
226 |
+
speaker_name, spec, breaks, _ = self.current_generated
|
227 |
+
assert spec is not None
|
228 |
+
|
229 |
+
# Initialize the vocoder model and make it determinstic, if user provides a seed
|
230 |
+
if self.ui.random_seed_checkbox.isChecked():
|
231 |
+
seed = int(self.ui.seed_textbox.text())
|
232 |
+
self.ui.populate_gen_options(seed, self.trim_silences)
|
233 |
+
else:
|
234 |
+
seed = None
|
235 |
+
|
236 |
+
if seed is not None:
|
237 |
+
torch.manual_seed(seed)
|
238 |
+
|
239 |
+
# Synthesize the waveform
|
240 |
+
if not vocoder.is_loaded() or seed is not None:
|
241 |
+
self.init_vocoder()
|
242 |
+
|
243 |
+
def vocoder_progress(i, seq_len, b_size, gen_rate):
|
244 |
+
real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
|
245 |
+
line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
|
246 |
+
% (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
|
247 |
+
self.ui.log(line, "overwrite")
|
248 |
+
self.ui.set_loading(i, seq_len)
|
249 |
+
if self.ui.current_vocoder_fpath is not None:
|
250 |
+
self.ui.log("")
|
251 |
+
wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
|
252 |
+
else:
|
253 |
+
self.ui.log("Waveform generation with Griffin-Lim... ")
|
254 |
+
wav = Synthesizer.griffin_lim(spec)
|
255 |
+
self.ui.set_loading(0)
|
256 |
+
self.ui.log(" Done!", "append")
|
257 |
+
|
258 |
+
# Add breaks
|
259 |
+
b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
|
260 |
+
b_starts = np.concatenate(([0], b_ends[:-1]))
|
261 |
+
wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
|
262 |
+
breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
|
263 |
+
wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
|
264 |
+
|
265 |
+
# Trim excessive silences
|
266 |
+
if self.ui.trim_silences_checkbox.isChecked():
|
267 |
+
wav = encoder.preprocess_wav(wav)
|
268 |
+
|
269 |
+
# Play it
|
270 |
+
wav = wav / np.abs(wav).max() * 0.97
|
271 |
+
self.ui.play(wav, Synthesizer.sample_rate)
|
272 |
+
|
273 |
+
# Name it (history displayed in combobox)
|
274 |
+
# TODO better naming for the combobox items?
|
275 |
+
wav_name = str(self.waves_count + 1)
|
276 |
+
|
277 |
+
#Update waves combobox
|
278 |
+
self.waves_count += 1
|
279 |
+
if self.waves_count > MAX_WAVS:
|
280 |
+
self.waves_list.pop()
|
281 |
+
self.waves_namelist.pop()
|
282 |
+
self.waves_list.insert(0, wav)
|
283 |
+
self.waves_namelist.insert(0, wav_name)
|
284 |
+
|
285 |
+
self.ui.waves_cb.disconnect()
|
286 |
+
self.ui.waves_cb_model.setStringList(self.waves_namelist)
|
287 |
+
self.ui.waves_cb.setCurrentIndex(0)
|
288 |
+
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
|
289 |
+
|
290 |
+
# Update current wav
|
291 |
+
self.set_current_wav(0)
|
292 |
+
|
293 |
+
#Enable replay and save buttons:
|
294 |
+
self.ui.replay_wav_button.setDisabled(False)
|
295 |
+
self.ui.export_wav_button.setDisabled(False)
|
296 |
+
|
297 |
+
# Compute the embedding
|
298 |
+
# TODO: this is problematic with different sampling rates, gotta fix it
|
299 |
+
if not encoder.is_loaded():
|
300 |
+
self.init_encoder()
|
301 |
+
encoder_wav = encoder.preprocess_wav(wav)
|
302 |
+
embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
|
303 |
+
|
304 |
+
# Add the utterance
|
305 |
+
name = speaker_name + "_gen_%05d" % np.random.randint(100000)
|
306 |
+
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True)
|
307 |
+
self.utterances.add(utterance)
|
308 |
+
|
309 |
+
# Plot it
|
310 |
+
self.ui.draw_embed(embed, name, "generated")
|
311 |
+
self.ui.draw_umap_projections(self.utterances)
|
312 |
+
|
313 |
+
def init_encoder(self):
|
314 |
+
model_fpath = self.ui.current_encoder_fpath
|
315 |
+
|
316 |
+
self.ui.log("Loading the encoder %s... " % model_fpath)
|
317 |
+
self.ui.set_loading(1)
|
318 |
+
start = timer()
|
319 |
+
encoder.load_model(model_fpath)
|
320 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
321 |
+
self.ui.set_loading(0)
|
322 |
+
|
323 |
+
def init_synthesizer(self):
|
324 |
+
model_fpath = self.ui.current_synthesizer_fpath
|
325 |
+
|
326 |
+
self.ui.log("Loading the synthesizer %s... " % model_fpath)
|
327 |
+
self.ui.set_loading(1)
|
328 |
+
start = timer()
|
329 |
+
self.synthesizer = Synthesizer(model_fpath)
|
330 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
331 |
+
self.ui.set_loading(0)
|
332 |
+
|
333 |
+
def init_vocoder(self):
|
334 |
+
model_fpath = self.ui.current_vocoder_fpath
|
335 |
+
# Case of Griffin-lim
|
336 |
+
if model_fpath is None:
|
337 |
+
return
|
338 |
+
|
339 |
+
self.ui.log("Loading the vocoder %s... " % model_fpath)
|
340 |
+
self.ui.set_loading(1)
|
341 |
+
start = timer()
|
342 |
+
vocoder.load_model(model_fpath)
|
343 |
+
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
344 |
+
self.ui.set_loading(0)
|
345 |
+
|
346 |
+
def update_seed_textbox(self):
|
347 |
+
self.ui.update_seed_textbox()
|
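A hypothetical launcher sketch, not part of the diff: upstream, the toolbox is normally started through a `demo_toolbox.py` entry point, which is not shown in this file view, and the paths below are placeholders for a local setup.

```python
from pathlib import Path

from toolbox import Toolbox

# Constructing Toolbox builds the Qt UI and blocks until the window is closed.
Toolbox(datasets_root=Path("datasets"), models_dir=Path("saved_models"), seed=None)
```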
toolbox/ui.py
ADDED
@@ -0,0 +1,607 @@
import sys
from pathlib import Path
from time import sleep
from typing import List, Set
from warnings import filterwarnings, warn

import matplotlib.pyplot as plt
import numpy as np
import sounddevice as sd
import soundfile as sf
import umap
from PyQt5.QtCore import Qt, QStringListModel
from PyQt5.QtWidgets import *
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas

from encoder.inference import plot_embedding_as_heatmap
from toolbox.utterance import Utterance

filterwarnings("ignore")


colormap = np.array([
    [0, 127, 70],
    [255, 0, 0],
    [255, 217, 38],
    [0, 135, 255],
    [165, 0, 165],
    [255, 167, 255],
    [97, 142, 151],
    [0, 255, 255],
    [255, 96, 38],
    [142, 76, 0],
    [33, 0, 127],
    [0, 0, 0],
    [183, 183, 183],
    [76, 255, 0],
], dtype=np.float) / 255

default_text = \
    "Welcome to the toolbox! To begin, load an utterance from your datasets or record one " \
    "yourself.\nOnce its embedding has been created, you can synthesize any text written here.\n" \
    "The synthesizer expects to generate " \
    "outputs that are somewhere between 5 and 12 seconds.\nTo mark breaks, write a new line. " \
    "Each line will be treated separately.\nThen, they are joined together to make the final " \
    "spectrogram. Use the vocoder to generate audio.\nThe vocoder generates almost in constant " \
    "time, so it will be more time efficient for longer inputs like this one.\nOn the left you " \
    "have the embedding projections. Load or record more utterances to see them.\nIf you have " \
    "at least 2 or 3 utterances from a same speaker, a cluster should form.\nSynthesized " \
    "utterances are of the same color as the speaker whose voice was used, but they're " \
    "represented with a cross."


class UI(QDialog):
    min_umap_points = 4
    max_log_lines = 5
    max_saved_utterances = 20

    def draw_utterance(self, utterance: Utterance, which):
        self.draw_spec(utterance.spec, which)
        self.draw_embed(utterance.embed, utterance.name, which)

    def draw_embed(self, embed, name, which):
        embed_ax, _ = self.current_ax if which == "current" else self.gen_ax
        embed_ax.figure.suptitle("" if embed is None else name)

        ## Embedding
        # Clear the plot
        if len(embed_ax.images) > 0:
            embed_ax.images[0].colorbar.remove()
        embed_ax.clear()

        # Draw the embed
        if embed is not None:
            plot_embedding_as_heatmap(embed, embed_ax)
            embed_ax.set_title("embedding")
        embed_ax.set_aspect("equal", "datalim")
        embed_ax.set_xticks([])
        embed_ax.set_yticks([])
        embed_ax.figure.canvas.draw()

    def draw_spec(self, spec, which):
        _, spec_ax = self.current_ax if which == "current" else self.gen_ax

        ## Spectrogram
        # Draw the spectrogram
        spec_ax.clear()
        if spec is not None:
            spec_ax.imshow(spec, aspect="auto", interpolation="none")
            spec_ax.set_title("mel spectrogram")

        spec_ax.set_xticks([])
        spec_ax.set_yticks([])
        spec_ax.figure.canvas.draw()
        if which != "current":
            self.vocode_button.setDisabled(spec is None)

    def draw_umap_projections(self, utterances: Set[Utterance]):
        self.umap_ax.clear()

        speakers = np.unique([u.speaker_name for u in utterances])
        colors = {speaker_name: colormap[i] for i, speaker_name in enumerate(speakers)}
        embeds = [u.embed for u in utterances]

        # Display a message if there aren't enough points
        if len(utterances) < self.min_umap_points:
            self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" %
                              (self.min_umap_points - len(utterances)),
                              horizontalalignment='center', fontsize=15)
            self.umap_ax.set_title("")

        # Compute the projections
        else:
            if not self.umap_hot:
                self.log(
                    "Drawing UMAP projections for the first time, this will take a few seconds.")
                self.umap_hot = True

            reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine")
            projections = reducer.fit_transform(embeds)

            speakers_done = set()
            for projection, utterance in zip(projections, utterances):
                color = colors[utterance.speaker_name]
                mark = "x" if "_gen_" in utterance.name else "o"
                label = None if utterance.speaker_name in speakers_done else utterance.speaker_name
                speakers_done.add(utterance.speaker_name)
                self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark,
                                     label=label)
            self.umap_ax.legend(prop={'size': 10})

        # Draw the plot
        self.umap_ax.set_aspect("equal", "datalim")
        self.umap_ax.set_xticks([])
        self.umap_ax.set_yticks([])
        self.umap_ax.figure.canvas.draw()

    def save_audio_file(self, wav, sample_rate):
        dialog = QFileDialog()
        dialog.setDefaultSuffix(".wav")
        fpath, _ = dialog.getSaveFileName(
            parent=self,
            caption="Select a path to save the audio file",
            filter="Audio Files (*.flac *.wav)"
        )
        if fpath:
            # Default format is wav
            if Path(fpath).suffix == "":
                fpath += ".wav"
            sf.write(fpath, wav, sample_rate)

    def setup_audio_devices(self, sample_rate):
        input_devices = []
        output_devices = []
        for device in sd.query_devices():
            # Check if valid input
            try:
                sd.check_input_settings(device=device["name"], samplerate=sample_rate)
                input_devices.append(device["name"])
            except:
                pass

            # Check if valid output
            try:
                sd.check_output_settings(device=device["name"], samplerate=sample_rate)
                output_devices.append(device["name"])
            except Exception as e:
                # Log a warning only if the device is not an input
                if not device["name"] in input_devices:
                    warn("Unsupported output device %s for the sample rate: %d \nError: %s" % (device["name"], sample_rate, str(e)))

        if len(input_devices) == 0:
            self.log("No audio input device detected. Recording may not work.")
            self.audio_in_device = None
        else:
            self.audio_in_device = input_devices[0]

        if len(output_devices) == 0:
            self.log("No supported output audio devices were found! Audio output may not work.")
            self.audio_out_devices_cb.addItems(["None"])
            self.audio_out_devices_cb.setDisabled(True)
        else:
            self.audio_out_devices_cb.clear()
            self.audio_out_devices_cb.addItems(output_devices)
            self.audio_out_devices_cb.currentTextChanged.connect(self.set_audio_device)

        self.set_audio_device()

    def set_audio_device(self):
        output_device = self.audio_out_devices_cb.currentText()
        if output_device == "None":
            output_device = None

        # If None, sounddevice queries portaudio
        sd.default.device = (self.audio_in_device, output_device)

    def play(self, wav, sample_rate):
        try:
            sd.stop()
            sd.play(wav, sample_rate)
        except Exception as e:
            print(e)
            self.log("Error in audio playback. Try selecting a different audio output device.")
            self.log("Your device must be connected before you start the toolbox.")

    def stop(self):
        sd.stop()

    def record_one(self, sample_rate, duration):
        self.record_button.setText("Recording...")
        self.record_button.setDisabled(True)

        self.log("Recording %d seconds of audio" % duration)
        sd.stop()
        try:
            wav = sd.rec(duration * sample_rate, sample_rate, 1)
        except Exception as e:
            print(e)
            self.log("Could not record anything. Is your recording device enabled?")
            self.log("Your device must be connected before you start the toolbox.")
            return None

        for i in np.arange(0, duration, 0.1):
            self.set_loading(i, duration)
            sleep(0.1)
        self.set_loading(duration, duration)
        sd.wait()

        self.log("Done recording.")
        self.record_button.setText("Record")
        self.record_button.setDisabled(False)

        return wav.squeeze()

    @property
    def current_dataset_name(self):
        return self.dataset_box.currentText()

    @property
    def current_speaker_name(self):
        return self.speaker_box.currentText()

    @property
    def current_utterance_name(self):
        return self.utterance_box.currentText()

    def browse_file(self):
        fpath = QFileDialog().getOpenFileName(
            parent=self,
            caption="Select an audio file",
            filter="Audio Files (*.mp3 *.flac *.wav *.m4a)"
        )
        return Path(fpath[0]) if fpath[0] != "" else ""

    @staticmethod
    def repopulate_box(box, items, random=False):
        """
        Resets a box and adds a list of items. Pass a list of (item, data) pairs instead to join
        data to the items
        """
        box.blockSignals(True)
        box.clear()
        for item in items:
            item = list(item) if isinstance(item, tuple) else [item]
            box.addItem(str(item[0]), *item[1:])
        if len(items) > 0:
            box.setCurrentIndex(np.random.randint(len(items)) if random else 0)
        box.setDisabled(len(items) == 0)
        box.blockSignals(False)

    def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int,
                         random=True):
        # Select a random dataset
        if level <= 0:
            if datasets_root is not None:
                datasets = [datasets_root.joinpath(d) for d in recognized_datasets]
                datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()]
                self.browser_load_button.setDisabled(len(datasets) == 0)
            if datasets_root is None or len(datasets) == 0:
                msg = "Warning: you d" + ("id not pass a root directory for datasets as argument" \
                    if datasets_root is None else "o not have any of the recognized datasets" \
                                                  " in %s" % datasets_root)
                self.log(msg)
                msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \
                       "can still use the toolbox by recording samples yourself." % \
                       ("\n\t".join(recognized_datasets))
                print(msg, file=sys.stderr)

                self.random_utterance_button.setDisabled(True)
                self.random_speaker_button.setDisabled(True)
                self.random_dataset_button.setDisabled(True)
                self.utterance_box.setDisabled(True)
                self.speaker_box.setDisabled(True)
                self.dataset_box.setDisabled(True)
                self.browser_load_button.setDisabled(True)
                self.auto_next_checkbox.setDisabled(True)
                return
            self.repopulate_box(self.dataset_box, datasets, random)

        # Select a random speaker
        if level <= 1:
            speakers_root = datasets_root.joinpath(self.current_dataset_name)
            speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()]
            self.repopulate_box(self.speaker_box, speaker_names, random)

        # Select a random utterance
        if level <= 2:
            utterances_root = datasets_root.joinpath(
                self.current_dataset_name,
                self.current_speaker_name
            )
            utterances = []
            for extension in ['mp3', 'flac', 'wav', 'm4a']:
                utterances.extend(Path(utterances_root).glob("**/*.%s" % extension))
            utterances = [fpath.relative_to(utterances_root) for fpath in utterances]
            self.repopulate_box(self.utterance_box, utterances, random)

    def browser_select_next(self):
        index = (self.utterance_box.currentIndex() + 1) % len(self.utterance_box)
        self.utterance_box.setCurrentIndex(index)

    @property
    def current_encoder_fpath(self):
        return self.encoder_box.itemData(self.encoder_box.currentIndex())

    @property
    def current_synthesizer_fpath(self):
        return self.synthesizer_box.itemData(self.synthesizer_box.currentIndex())

    @property
    def current_vocoder_fpath(self):
        return self.vocoder_box.itemData(self.vocoder_box.currentIndex())

    def populate_models(self, models_dir: Path):
        # Encoder
        encoder_fpaths = list(models_dir.glob("*/encoder.pt"))
        if len(encoder_fpaths) == 0:
            raise Exception("No encoder models found in %s" % models_dir)
        self.repopulate_box(self.encoder_box, [(f.parent.name, f) for f in encoder_fpaths])

        # Synthesizer
        synthesizer_fpaths = list(models_dir.glob("*/synthesizer.pt"))
        if len(synthesizer_fpaths) == 0:
            raise Exception("No synthesizer models found in %s" % models_dir)
        self.repopulate_box(self.synthesizer_box, [(f.parent.name, f) for f in synthesizer_fpaths])

        # Vocoder
        vocoder_fpaths = list(models_dir.glob("*/vocoder.pt"))
        vocoder_items = [(f.parent.name, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)]
        self.repopulate_box(self.vocoder_box, vocoder_items)

    @property
    def selected_utterance(self):
        return self.utterance_history.itemData(self.utterance_history.currentIndex())

    def register_utterance(self, utterance: Utterance):
        self.utterance_history.blockSignals(True)
        self.utterance_history.insertItem(0, utterance.name, utterance)
        self.utterance_history.setCurrentIndex(0)
        self.utterance_history.blockSignals(False)

        if len(self.utterance_history) > self.max_saved_utterances:
            self.utterance_history.removeItem(self.max_saved_utterances)

        self.play_button.setDisabled(False)
        self.generate_button.setDisabled(False)
        self.synthesize_button.setDisabled(False)

    def log(self, line, mode="newline"):
        if mode == "newline":
            self.logs.append(line)
            if len(self.logs) > self.max_log_lines:
                del self.logs[0]
        elif mode == "append":
            self.logs[-1] += line
        elif mode == "overwrite":
            self.logs[-1] = line
        log_text = '\n'.join(self.logs)

        self.log_window.setText(log_text)
        self.app.processEvents()

    def set_loading(self, value, maximum=1):
        self.loading_bar.setValue(value * 100)
        self.loading_bar.setMaximum(maximum * 100)
        self.loading_bar.setTextVisible(value != 0)
        self.app.processEvents()

    def populate_gen_options(self, seed, trim_silences):
        if seed is not None:
            self.random_seed_checkbox.setChecked(True)
            self.seed_textbox.setText(str(seed))
            self.seed_textbox.setEnabled(True)
        else:
            self.random_seed_checkbox.setChecked(False)
            self.seed_textbox.setText(str(0))
            self.seed_textbox.setEnabled(False)

        if not trim_silences:
            self.trim_silences_checkbox.setChecked(False)
            self.trim_silences_checkbox.setDisabled(True)

    def update_seed_textbox(self):
        if self.random_seed_checkbox.isChecked():
            self.seed_textbox.setEnabled(True)
        else:
            self.seed_textbox.setEnabled(False)

    def reset_interface(self):
        self.draw_embed(None, None, "current")
        self.draw_embed(None, None, "generated")
        self.draw_spec(None, "current")
        self.draw_spec(None, "generated")
        self.draw_umap_projections(set())
        self.set_loading(0)
        self.play_button.setDisabled(True)
        self.generate_button.setDisabled(True)
        self.synthesize_button.setDisabled(True)
        self.vocode_button.setDisabled(True)
        self.replay_wav_button.setDisabled(True)
        self.export_wav_button.setDisabled(True)
        [self.log("") for _ in range(self.max_log_lines)]

    def __init__(self):
        ## Initialize the application
        self.app = QApplication(sys.argv)
        super().__init__(None)
        self.setWindowTitle("SV2TTS toolbox")


        ## Main layouts
        # Root
        root_layout = QGridLayout()
        self.setLayout(root_layout)

        # Browser
        browser_layout = QGridLayout()
        root_layout.addLayout(browser_layout, 0, 0, 1, 2)

        # Generation
        gen_layout = QVBoxLayout()
        root_layout.addLayout(gen_layout, 0, 2, 1, 2)

        # Projections
        self.projections_layout = QVBoxLayout()
        root_layout.addLayout(self.projections_layout, 1, 0, 1, 1)

        # Visualizations
        vis_layout = QVBoxLayout()
        root_layout.addLayout(vis_layout, 1, 1, 1, 3)


        ## Projections
        # UMAP
        fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0")
        fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98)
        self.projections_layout.addWidget(FigureCanvas(fig))
        self.umap_hot = False
        self.clear_button = QPushButton("Clear")
        self.projections_layout.addWidget(self.clear_button)


        ## Browser
        # Dataset, speaker and utterance selection
        i = 0
        self.dataset_box = QComboBox()
        browser_layout.addWidget(QLabel("<b>Dataset</b>"), i, 0)
        browser_layout.addWidget(self.dataset_box, i + 1, 0)
        self.speaker_box = QComboBox()
        browser_layout.addWidget(QLabel("<b>Speaker</b>"), i, 1)
        browser_layout.addWidget(self.speaker_box, i + 1, 1)
        self.utterance_box = QComboBox()
        browser_layout.addWidget(QLabel("<b>Utterance</b>"), i, 2)
        browser_layout.addWidget(self.utterance_box, i + 1, 2)
        self.browser_load_button = QPushButton("Load")
        browser_layout.addWidget(self.browser_load_button, i + 1, 3)
        i += 2

        # Random buttons
        self.random_dataset_button = QPushButton("Random")
        browser_layout.addWidget(self.random_dataset_button, i, 0)
        self.random_speaker_button = QPushButton("Random")
        browser_layout.addWidget(self.random_speaker_button, i, 1)
        self.random_utterance_button = QPushButton("Random")
        browser_layout.addWidget(self.random_utterance_button, i, 2)
        self.auto_next_checkbox = QCheckBox("Auto select next")
        self.auto_next_checkbox.setChecked(True)
        browser_layout.addWidget(self.auto_next_checkbox, i, 3)
        i += 1

        # Utterance box
        browser_layout.addWidget(QLabel("<b>Use embedding from:</b>"), i, 0)
        self.utterance_history = QComboBox()
        browser_layout.addWidget(self.utterance_history, i, 1, 1, 3)
        i += 1

        # Browse, record and playback buttons
        self.browser_browse_button = QPushButton("Browse")
        browser_layout.addWidget(self.browser_browse_button, i, 0)
        self.record_button = QPushButton("Record")
        browser_layout.addWidget(self.record_button, i, 1)
        self.play_button = QPushButton("Play")
        browser_layout.addWidget(self.play_button, i, 2)
        self.stop_button = QPushButton("Stop")
        browser_layout.addWidget(self.stop_button, i, 3)
        i += 1


        # Model and audio output selection
        self.encoder_box = QComboBox()
        browser_layout.addWidget(QLabel("<b>Encoder</b>"), i, 0)
        browser_layout.addWidget(self.encoder_box, i + 1, 0)
        self.synthesizer_box = QComboBox()
        browser_layout.addWidget(QLabel("<b>Synthesizer</b>"), i, 1)
        browser_layout.addWidget(self.synthesizer_box, i + 1, 1)
        self.vocoder_box = QComboBox()
        browser_layout.addWidget(QLabel("<b>Vocoder</b>"), i, 2)
        browser_layout.addWidget(self.vocoder_box, i + 1, 2)

        self.audio_out_devices_cb = QComboBox()
        browser_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 3)
        browser_layout.addWidget(self.audio_out_devices_cb, i + 1, 3)
        i += 2

        # Replay & save audio
        browser_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
        self.waves_cb = QComboBox()
        self.waves_cb_model = QStringListModel()
        self.waves_cb.setModel(self.waves_cb_model)
        self.waves_cb.setToolTip("Select one of the last generated waves in this section for replaying or exporting")
        browser_layout.addWidget(self.waves_cb, i, 1)
        self.replay_wav_button = QPushButton("Replay")
        self.replay_wav_button.setToolTip("Replay last generated vocoder output")
        browser_layout.addWidget(self.replay_wav_button, i, 2)
        self.export_wav_button = QPushButton("Export")
        self.export_wav_button.setToolTip("Save last generated vocoder audio in filesystem as a wav file")
        browser_layout.addWidget(self.export_wav_button, i, 3)
        i += 1


        ## Embed & spectrograms
        vis_layout.addStretch()

        gridspec_kw = {"width_ratios": [1, 4]}
        fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
                                            gridspec_kw=gridspec_kw)
        fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
        vis_layout.addWidget(FigureCanvas(fig))

        fig, self.gen_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
                                        gridspec_kw=gridspec_kw)
        fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
        vis_layout.addWidget(FigureCanvas(fig))

        for ax in self.current_ax.tolist() + self.gen_ax.tolist():
            ax.set_facecolor("#F0F0F0")
            for side in ["top", "right", "bottom", "left"]:
                ax.spines[side].set_visible(False)


        ## Generation
        self.text_prompt = QPlainTextEdit(default_text)
        gen_layout.addWidget(self.text_prompt, stretch=1)

        self.generate_button = QPushButton("Synthesize and vocode")
        gen_layout.addWidget(self.generate_button)

        layout = QHBoxLayout()
        self.synthesize_button = QPushButton("Synthesize only")
        layout.addWidget(self.synthesize_button)
        self.vocode_button = QPushButton("Vocode only")
        layout.addWidget(self.vocode_button)
        gen_layout.addLayout(layout)

        layout_seed = QGridLayout()
        self.random_seed_checkbox = QCheckBox("Random seed:")
        self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
        layout_seed.addWidget(self.random_seed_checkbox, 0, 0)
        self.seed_textbox = QLineEdit()
        self.seed_textbox.setMaximumWidth(80)
        layout_seed.addWidget(self.seed_textbox, 0, 1)
        self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
        self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
                                               " This feature requires `webrtcvad` to be installed.")
        layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
        gen_layout.addLayout(layout_seed)

        self.loading_bar = QProgressBar()
        gen_layout.addWidget(self.loading_bar)

        self.log_window = QLabel()
        self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft)
        gen_layout.addWidget(self.log_window)
        self.logs = []
        gen_layout.addStretch()


        ## Set the size of the window and of the elements
        max_size = QDesktopWidget().availableGeometry(self).size() * 0.8
        self.resize(max_size)

        ## Finalize the display
        self.reset_interface()
        self.show()

    def start(self):
        self.app.exec_()
toolbox/utterance.py
ADDED
@@ -0,0 +1,5 @@
from collections import namedtuple

Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth")
Utterance.__eq__ = lambda x, y: x.name == y.name
Utterance.__hash__ = lambda x: hash(x.name)
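A small sketch, not part of the diff, showing why `__eq__` and `__hash__` are overridden: the toolbox stores utterances in a `set`, so comparing and hashing by name keeps one entry per utterance name.

```python
from toolbox.utterance import Utterance

a = Utterance("p240_00000", "VCTK_p240", None, None, None, None, False)
b = Utterance("p240_00000", "VCTK_p240", None, None, None, None, True)
print(a == b)       # True: equality is by name only
print(len({a, b}))  # 1: the set deduplicates by name
```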
utils/__init__.py
ADDED
File without changes

utils/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (157 Bytes).

utils/__pycache__/argutils.cpython-37.pyc
ADDED
Binary file (1.69 kB).

utils/__pycache__/default_models.cpython-37.pyc
ADDED
Binary file (2.26 kB).
utils/argutils.py
ADDED
@@ -0,0 +1,40 @@
from pathlib import Path
import numpy as np
import argparse

_type_priorities = [    # In decreasing order
    Path,
    str,
    int,
    float,
    bool,
]

def _priority(o):
    p = next((i for i, t in enumerate(_type_priorities) if type(o) is t), None)
    if p is not None:
        return p
    p = next((i for i, t in enumerate(_type_priorities) if isinstance(o, t)), None)
    if p is not None:
        return p
    return len(_type_priorities)

def print_args(args: argparse.Namespace, parser=None):
    args = vars(args)
    if parser is None:
        priorities = list(map(_priority, args.values()))
    else:
        all_params = [a.dest for g in parser._action_groups for a in g._group_actions]
        priority = lambda p: all_params.index(p) if p in all_params else len(all_params)
        priorities = list(map(priority, args.keys()))

    pad = max(map(len, args.keys())) + 3
    indices = np.lexsort((list(args.keys()), priorities))
    items = list(args.items())

    print("Arguments:")
    for i in indices:
        param, value = items[i]
        print("    {0}:{1}{2}".format(param, ' ' * (pad - len(param)), value))
    print("")
utils/default_models.py
ADDED
@@ -0,0 +1,56 @@
import urllib.request
from pathlib import Path
from threading import Thread
from urllib.error import HTTPError

from tqdm import tqdm


default_models = {
    "encoder": ("https://drive.google.com/uc?export=download&id=1q8mEGwCkFy23KZsinbuvdKAQLqNKbYf1", 17090379),
    "synthesizer": ("https://drive.google.com/u/0/uc?id=1EqFMIbvxffxtjiVrtykroF6_mUh-5Z3s&export=download&confirm=t", 370554559),
    "vocoder": ("https://drive.google.com/uc?export=download&id=1cf2NO6FtI0jDuy8AV3Xgn6leO6dHjIgu", 53845290),
}


class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)


def download(url: str, target: Path, bar_pos=0):
    # Ensure the directory exists
    target.parent.mkdir(exist_ok=True, parents=True)

    desc = f"Downloading {target.name}"
    with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=desc, position=bar_pos, leave=False) as t:
        try:
            urllib.request.urlretrieve(url, filename=target, reporthook=t.update_to)
        except HTTPError:
            return


def ensure_default_models(models_dir: Path):
    # Define download tasks
    jobs = []
    for model_name, (url, size) in default_models.items():
        target_path = models_dir / "default" / f"{model_name}.pt"
        if target_path.exists():
            if target_path.stat().st_size != size:
                print(f"File {target_path} is not of expected size, redownloading...")
            else:
                continue

        thread = Thread(target=download, args=(url, target_path, len(jobs)))
        thread.start()
        jobs.append((thread, target_path, size))

    # Run and join threads
    for thread, target_path, size in jobs:
        thread.join()

        assert target_path.exists() and target_path.stat().st_size == size, \
            f"Download for {target_path.name} failed. You may download models manually instead.\n" \
            f"https://drive.google.com/drive/folders/1fU6umc5uQAVR2udZdHX-lDgXYzTyqG_j"
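A usage sketch for the helper above, not part of the diff; it assumes the working directory is the repository root, so the pretrained weights land in `saved_models/default/` as expected by the toolbox.

```python
from pathlib import Path

from utils.default_models import ensure_default_models

# Downloads encoder.pt, synthesizer.pt and vocoder.pt into saved_models/default/ if missing.
ensure_default_models(Path("saved_models"))
```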
utils/logmmse.py
ADDED
@@ -0,0 +1,247 @@
# The MIT License (MIT)
#
# Copyright (c) 2015 braindead
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
#
# This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I
# simply modified the interface to meet my needs.


import numpy as np
import math
from scipy.special import expn
from collections import namedtuple

NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2")


def profile_noise(noise, sampling_rate, window_size=0):
    """
    Creates a profile of the noise in a given waveform.

    :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints.
    :param sampling_rate: the sampling rate of the audio
    :param window_size: the size of the window the logmmse algorithm operates on. A default value
    will be picked if left as 0.
    :return: a NoiseProfile object
    """
    noise, dtype = to_float(noise)
    noise += np.finfo(np.float64).eps

    if window_size == 0:
        window_size = int(math.floor(0.02 * sampling_rate))

    if window_size % 2 == 1:
        window_size = window_size + 1

    perc = 50
    len1 = int(math.floor(window_size * perc / 100))
    len2 = int(window_size - len1)

    win = np.hanning(window_size)
    win = win * len2 / np.sum(win)
    n_fft = 2 * window_size

    noise_mean = np.zeros(n_fft)
    n_frames = len(noise) // window_size
    for j in range(0, window_size * n_frames, window_size):
        noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0))
    noise_mu2 = (noise_mean / n_frames) ** 2

    return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2)


def denoise(wav, noise_profile: NoiseProfile, eta=0.15):
    """
    Cleans the noise from a speech waveform given a noise profile. The waveform must have the
    same sampling rate as the one used to create the noise profile.

    :param wav: a speech waveform as a numpy array of floats or ints.
    :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of
    the same) waveform.
    :param eta: voice threshold for noise update. While the voice activation detection value is
    below this threshold, the noise profile will be continuously updated throughout the audio.
    Set to 0 to disable updating the noise profile.
    :return: the clean wav as a numpy array of floats or ints of the same length.
    """
    wav, dtype = to_float(wav)
    wav += np.finfo(np.float64).eps
    p = noise_profile

    nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2))
    x_final = np.zeros(nframes * p.len2)

    aa = 0.98
    mu = 0.98
    ksi_min = 10 ** (-25 / 10)

    x_old = np.zeros(p.len1)
    xk_prev = np.zeros(p.len1)
    noise_mu2 = p.noise_mu2
    for k in range(0, nframes * p.len2, p.len2):
        insign = p.win * wav[k:k + p.window_size]

        spec = np.fft.fft(insign, p.n_fft, axis=0)
        sig = np.absolute(spec)
        sig2 = sig ** 2

        gammak = np.minimum(sig2 / noise_mu2, 40)

        if xk_prev.all() == 0:
            ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
        else:
            ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
            ksi = np.maximum(ksi_min, ksi)

        log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
        vad_decision = np.sum(log_sigma_k) / p.window_size
        if vad_decision < eta:
            noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2

        a = ksi / (1 + ksi)
        vk = a * gammak
        ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
        hw = a * np.exp(ei_vk)
        sig = sig * hw
        xk_prev = sig ** 2
        xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0)
        xi_w = np.real(xi_w)

        x_final[k:k + p.len2] = x_old + xi_w[0:p.len1]
        x_old = xi_w[p.len1:p.window_size]

    output = from_float(x_final, dtype)
    output = np.pad(output, (0, len(wav) - len(output)), mode="constant")
    return output


## Alternative VAD algorithm to webrtcvad. It has the advantage of not requiring to install that
## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of
## webrtcvad
# def vad(wav, sampling_rate, eta=0.15, window_size=0):
#     """
#     TODO: fix doc
#     Creates a profile of the noise in a given waveform.
#
#     :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints.
#     :param sampling_rate: the sampling rate of the audio
#     :param window_size: the size of the window the logmmse algorithm operates on. A default value
#     will be picked if left as 0.
#     :param eta: voice threshold for noise update. While the voice activation detection value is
#     below this threshold, the noise profile will be continuously updated throughout the audio.
#     Set to 0 to disable updating the noise profile.
#     """
#     wav, dtype = to_float(wav)
#     wav += np.finfo(np.float64).eps
#
#     if window_size == 0:
#         window_size = int(math.floor(0.02 * sampling_rate))
#
#     if window_size % 2 == 1:
#         window_size = window_size + 1
#
#     perc = 50
#     len1 = int(math.floor(window_size * perc / 100))
#     len2 = int(window_size - len1)
#
#     win = np.hanning(window_size)
#     win = win * len2 / np.sum(win)
#     n_fft = 2 * window_size
#
#     wav_mean = np.zeros(n_fft)
#     n_frames = len(wav) // window_size
#     for j in range(0, window_size * n_frames, window_size):
#         wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0))
#     noise_mu2 = (wav_mean / n_frames) ** 2
#
#     wav, dtype = to_float(wav)
#     wav += np.finfo(np.float64).eps
#
#     nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2))
#     vad = np.zeros(nframes * len2, dtype=np.bool)
#
#     aa = 0.98
#     mu = 0.98
#     ksi_min = 10 ** (-25 / 10)
#
#     xk_prev = np.zeros(len1)
#     noise_mu2 = noise_mu2
#     for k in range(0, nframes * len2, len2):
#         insign = win * wav[k:k + window_size]
#
#         spec = np.fft.fft(insign, n_fft, axis=0)
#         sig = np.absolute(spec)
#         sig2 = sig ** 2
#
#         gammak = np.minimum(sig2 / noise_mu2, 40)
#
#         if xk_prev.all() == 0:
#             ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
#         else:
#             ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
#             ksi = np.maximum(ksi_min, ksi)
#
#         log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
#         vad_decision = np.sum(log_sigma_k) / window_size
#         if vad_decision < eta:
#             noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
#         print(vad_decision)
#
#         a = ksi / (1 + ksi)
#         vk = a * gammak
#         ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
#         hw = a * np.exp(ei_vk)
#         sig = sig * hw
#         xk_prev = sig ** 2
#
#         vad[k:k + len2] = vad_decision >= eta
#
#     vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant")
#     return vad


def to_float(_input):
    if _input.dtype == np.float64:
        return _input, _input.dtype
    elif _input.dtype == np.float32:
        return _input.astype(np.float64), _input.dtype
    elif _input.dtype == np.uint8:
        return (_input - 128) / 128., _input.dtype
    elif _input.dtype == np.int16:
        return _input / 32768., _input.dtype
    elif _input.dtype == np.int32:
        return _input / 2147483648., _input.dtype
    raise ValueError('Unsupported wave file format')


def from_float(_input, dtype):
    if dtype == np.float64:
        return _input, np.float64
    elif dtype == np.float32:
        return _input.astype(np.float32)
    elif dtype == np.uint8:
        return ((_input * 128) + 128).astype(np.uint8)
    elif dtype == np.int16:
        return (_input * 32768).astype(np.int16)
    elif dtype == np.int32:
        print(_input)
        return (_input * 2147483648).astype(np.int32)
    raise ValueError('Unsupported wave file format')
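A usage sketch for the two public functions above, not part of the diff. The random signal is only a stand-in, and the assumption that the first half-second contains noise only is illustrative; in practice you would pass a real recording and a segment known to be speech-free.

```python
import numpy as np

from utils.logmmse import profile_noise, denoise

sampling_rate = 16000
wav = np.random.randn(3 * sampling_rate).astype(np.float32)   # stand-in for a real recording
profile = profile_noise(wav[:sampling_rate // 2], sampling_rate)
cleaned = denoise(wav, profile)                               # same length and dtype family as wav
```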