import torch
import numpy as np
from torch.utils.data import Dataset


class WaveRNNDataset(Dataset):
    """
    WaveRNN Dataset searchs for all the wav files under root path
    and converts them to acoustic features on the fly.
    """

    def __init__(self,
                 ap,
                 items,
                 seq_len,
                 hop_len,
                 pad,
                 mode,
                 mulaw,
                 is_training=True,
                 verbose=False,
                 ):

        self.ap = ap
        self.compute_feat = not isinstance(items[0], (tuple, list))
        self.item_list = items
        self.seq_len = seq_len
        self.hop_len = hop_len
        self.mel_len = seq_len // hop_len
        self.pad = pad
        self.mode = mode
        self.mulaw = mulaw
        self.is_training = is_training
        self.verbose = verbose

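        # a training window must span a whole number of hop-sized mel frames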
        assert self.seq_len % self.hop_len == 0

    def __len__(self):
        return len(self.item_list)

    def __getitem__(self, index):
        item = self.load_item(index)
        return item

    def load_item(self, index):
        """
        load (audio, feat) couple if feature_path is set
        else compute it on the fly
        """
        if self.compute_feat:

            wavpath = self.item_list[index]
            audio = self.ap.load_wav(wavpath)
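            # clips must be long enough to crop a full training window; pad short ones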
            min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
            if audio.shape[0] < min_audio_len:
                print(" [!] Instance is too short! : {}".format(wavpath))
                audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
            mel = self.ap.melspectrogram(audio)

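            # continuous modes ("gauss", "mold") use raw audio as input;
            # integer modes use quantized samples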
            if self.mode in ["gauss", "mold"]:
                x_input = audio
            elif isinstance(self.mode, int):
                x_input = (self.ap.mulaw_encode(audio, qc=self.mode)
                           if self.mulaw else self.ap.quantize(audio, bits=self.mode))
            else:
                raise RuntimeError("Unknown dataset mode - ", self.mode)

        else:

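            # precomputed features live in parallel "/quant/" and "/mel/" directories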
            wavpath, feat_path = self.item_list[index]
            mel = np.load(feat_path.replace("/quant/", "/mel/"))

            if mel.shape[-1] < self.mel_len + 2 * self.pad:
                print(" [!] Instance is too short! : {}".format(wavpath))
                # skip to the next item when this one is too short
                self.item_list[index] = self.item_list[index + 1]
                wavpath, feat_path = self.item_list[index]
                mel = np.load(feat_path.replace("/quant/", "/mel/"))
            if self.mode in ["gauss", "mold"]:
                x_input = self.ap.load_wav(wavpath)
            elif isinstance(self.mode, int):
                x_input = np.load(feat_path.replace("/mel/", "/quant/"))
            else:
                raise RuntimeError("Unknown dataset mode - ", self.mode)

        return mel, x_input, wavpath

    def collate(self, batch):
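        # mel frames per training window, plus pad frames on each side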
        mel_win = self.seq_len // self.hop_len + 2 * self.pad
        max_offsets = [x[0].shape[-1] -
                       (mel_win + 2 * self.pad) for x in batch]

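        # pick a random crop per item; mel offsets are in frames, signal offsets in samples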
        mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
        sig_offsets = [(offset + self.pad) *
                       self.hop_len for offset in mel_offsets]

        mels = [
            x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win]
            for i, x in enumerate(batch)
        ]

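        # take seq_len + 1 samples so input and target can be shifted by one step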
        coarse = [
            x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1]
            for i, x in enumerate(batch)
        ]

        mels = np.stack(mels).astype(np.float32)
        if self.mode in ["gauss", "mold"]:
            coarse = np.stack(coarse).astype(np.float32)
            coarse = torch.FloatTensor(coarse)
            x_input = coarse[:, : self.seq_len]
        elif isinstance(self.mode, int):
            coarse = np.stack(coarse).astype(np.int64)
            coarse = torch.LongTensor(coarse)
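            # rescale quantized integer labels to floats in [-1, 1] for the model input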
            x_input = (2 * coarse[:, : self.seq_len].float() /
                       (2 ** self.mode - 1.0) - 1.0)
        else:
            raise RuntimeError("Unknown dataset mode - {}".format(self.mode))
        # teacher forcing: the target is the input shifted one sample ahead
        y_coarse = coarse[:, 1:]
        mels = torch.FloatTensor(mels)
        return x_input, mels, y_coarse
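

# --------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# an audio-processor object `ap` exposing load_wav / melspectrogram /
# mulaw_encode / quantize, and a list of wav paths `wav_paths`; both names are
# hypothetical placeholders.
#
#   from torch.utils.data import DataLoader
#
#   dataset = WaveRNNDataset(
#       ap=ap,
#       items=wav_paths,  # plain paths -> features are computed on the fly
#       seq_len=1280,     # audio samples per training window
#       hop_len=256,      # mel hop length; seq_len must be divisible by it
#       pad=2,            # extra mel frames of context on each side
#       mode=10,          # int -> 10-bit quantized output; "gauss"/"mold" -> continuous
#       mulaw=True,
#   )
#   loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate)
#   x_input, mels, y_coarse = next(iter(loader))
# --------------------------------------------------------------------------- #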