ssl-aasist · PyTorch · custom_code

Commit 72848ba (verified) · ash56 committed · 1 parent: 96ffdcd

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +1 -0
  2. fairseq/fairseq/data/audio/__init__.py +93 -0
  3. fairseq/fairseq/data/audio/__pycache__/__init__.cpython-310.pyc +0 -0
  4. fairseq/fairseq/data/audio/__pycache__/audio_utils.cpython-310.pyc +0 -0
  5. fairseq/fairseq/data/audio/__pycache__/data_cfg.cpython-310.pyc +0 -0
  6. fairseq/fairseq/data/audio/__pycache__/frm_text_to_speech_dataset.cpython-310.pyc +0 -0
  7. fairseq/fairseq/data/audio/__pycache__/hubert_dataset.cpython-310.pyc +0 -0
  8. fairseq/fairseq/data/audio/__pycache__/raw_audio_dataset.cpython-310.pyc +0 -0
  9. fairseq/fairseq/data/audio/__pycache__/speech_to_speech_dataset.cpython-310.pyc +0 -0
  10. fairseq/fairseq/data/audio/__pycache__/speech_to_text_dataset.cpython-310.pyc +0 -0
  11. fairseq/fairseq/data/audio/__pycache__/text_to_speech_dataset.cpython-310.pyc +0 -0
  12. fairseq/fairseq/data/audio/audio_utils.py +389 -0
  13. fairseq/fairseq/data/audio/data_cfg.py +387 -0
  14. fairseq/fairseq/data/audio/dataset_transforms/__init__.py +53 -0
  15. fairseq/fairseq/data/audio/dataset_transforms/__pycache__/__init__.cpython-310.pyc +0 -0
  16. fairseq/fairseq/data/audio/dataset_transforms/__pycache__/noisyoverlapaugment.cpython-310.pyc +0 -0
  17. fairseq/fairseq/data/audio/dataset_transforms/concataugment.py +61 -0
  18. fairseq/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py +105 -0
  19. fairseq/fairseq/data/audio/feature_transforms/__init__.py +43 -0
  20. fairseq/fairseq/data/audio/feature_transforms/__pycache__/__init__.cpython-310.pyc +0 -0
  21. fairseq/fairseq/data/audio/feature_transforms/__pycache__/global_cmvn.cpython-310.pyc +0 -0
  22. fairseq/fairseq/data/audio/feature_transforms/__pycache__/specaugment.cpython-310.pyc +0 -0
  23. fairseq/fairseq/data/audio/feature_transforms/__pycache__/utterance_cmvn.cpython-310.pyc +0 -0
  24. fairseq/fairseq/data/audio/feature_transforms/delta_deltas.py +37 -0
  25. fairseq/fairseq/data/audio/feature_transforms/specaugment.py +131 -0
  26. fairseq/fairseq/data/audio/feature_transforms/utterance_cmvn.py +41 -0
  27. fairseq/fairseq/data/audio/frm_text_to_speech_dataset.py +205 -0
  28. fairseq/fairseq/data/audio/hubert_dataset.py +356 -0
  29. fairseq/fairseq/data/audio/multi_modality_dataset.py +267 -0
  30. fairseq/fairseq/data/audio/raw_audio_dataset.py +431 -0
  31. fairseq/fairseq/data/audio/speech_to_speech_dataset.py +379 -0
  32. fairseq/fairseq/data/audio/speech_to_text_joint_dataset.py +359 -0
  33. fairseq/fairseq/data/audio/text_to_speech_dataset.py +250 -0
  34. fairseq/fairseq/data/audio/waveform_transforms/__init__.py +48 -0
  35. fairseq/fairseq/data/audio/waveform_transforms/__pycache__/__init__.cpython-310.pyc +0 -0
  36. fairseq/fairseq/data/audio/waveform_transforms/__pycache__/noiseaugment.cpython-310.pyc +0 -0
  37. fairseq/fairseq/data/audio/waveform_transforms/noiseaugment.py +201 -0
  38. fairseq/fairseq/data/data_utils_fast.cpp +0 -0
  39. fairseq/fairseq/data/encoders/__init__.py +29 -0
  40. fairseq/fairseq/data/encoders/__pycache__/__init__.cpython-310.pyc +0 -0
  41. fairseq/fairseq/data/encoders/__pycache__/byte_bpe.cpython-310.pyc +0 -0
  42. fairseq/fairseq/data/encoders/__pycache__/byte_utils.cpython-310.pyc +0 -0
  43. fairseq/fairseq/data/encoders/__pycache__/bytes.cpython-310.pyc +0 -0
  44. fairseq/fairseq/data/encoders/__pycache__/characters.cpython-310.pyc +0 -0
  45. fairseq/fairseq/data/encoders/__pycache__/fastbpe.cpython-310.pyc +0 -0
  46. fairseq/fairseq/data/encoders/__pycache__/gpt2_bpe.cpython-310.pyc +0 -0
  47. fairseq/fairseq/data/encoders/__pycache__/hf_bert_bpe.cpython-310.pyc +0 -0
  48. fairseq/fairseq/data/encoders/__pycache__/moses_tokenizer.cpython-310.pyc +0 -0
  49. fairseq/fairseq/data/encoders/__pycache__/sentencepiece_bpe.cpython-310.pyc +0 -0
  50. fairseq/fairseq/data/encoders/byte_bpe.py +48 -0
.gitattributes CHANGED
@@ -41,3 +41,4 @@ fairseq/docs/fairseq.gif filter=lfs diff=lfs merge=lfs -text
 fairseq/examples/hubert/tests/6313-76958-0021.flac filter=lfs diff=lfs merge=lfs -text
 fairseq/examples/textless_nlp/speech-resynth/img/fig.png filter=lfs diff=lfs merge=lfs -text
 fairseq/fairseq/libbase.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+fairseq/fairseq/libnat.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
fairseq/fairseq/data/audio/__init__.py ADDED
@@ -0,0 +1,93 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict, Optional
3
+ import importlib
4
+ import os
5
+ import numpy as np
6
+
7
+
8
+ class AudioTransform(ABC):
9
+ @classmethod
10
+ @abstractmethod
11
+ def from_config_dict(cls, config: Optional[Dict] = None):
12
+ pass
13
+
14
+
15
+ class CompositeAudioTransform(AudioTransform):
16
+ def _from_config_dict(
17
+ cls,
18
+ transform_type,
19
+ get_audio_transform,
20
+ composite_cls,
21
+ config=None,
22
+ return_empty=False,
23
+ ):
24
+ _config = {} if config is None else config
25
+ _transforms = _config.get(f"{transform_type}_transforms")
26
+
27
+ if _transforms is None:
28
+ if return_empty:
29
+ _transforms = []
30
+ else:
31
+ return None
32
+
33
+ transforms = [
34
+ get_audio_transform(_t).from_config_dict(_config.get(_t))
35
+ for _t in _transforms
36
+ ]
37
+ return composite_cls(transforms)
38
+
39
+ def __init__(self, transforms):
40
+ self.transforms = [t for t in transforms if t is not None]
41
+
42
+ def __call__(self, x):
43
+ for t in self.transforms:
44
+ x = t(x)
45
+ return x
46
+
47
+ def __repr__(self):
48
+ format_string = (
49
+ [self.__class__.__name__ + "("]
50
+ + [f" {t.__repr__()}" for t in self.transforms]
51
+ + [")"]
52
+ )
53
+ return "\n".join(format_string)
54
+
55
+
56
+ def register_audio_transform(name, cls_type, registry, class_names):
57
+ def register_audio_transform_cls(cls):
58
+ if name in registry:
59
+ raise ValueError(f"Cannot register duplicate transform ({name})")
60
+ if not issubclass(cls, cls_type):
61
+ raise ValueError(
62
+ f"Transform ({name}: {cls.__name__}) must extend "
63
+ f"{cls_type.__name__}"
64
+ )
65
+ if cls.__name__ in class_names:
66
+ raise ValueError(
67
+ f"Cannot register audio transform with duplicate "
68
+ f"class name ({cls.__name__})"
69
+ )
70
+ registry[name] = cls
71
+ class_names.add(cls.__name__)
72
+ return cls
73
+
74
+ return register_audio_transform_cls
75
+
76
+
77
+ def import_transforms(transforms_dir, transform_type):
78
+ for file in os.listdir(transforms_dir):
79
+ path = os.path.join(transforms_dir, file)
80
+ if (
81
+ not file.startswith("_")
82
+ and not file.startswith(".")
83
+ and (file.endswith(".py") or os.path.isdir(path))
84
+ ):
85
+ name = file[: file.find(".py")] if file.endswith(".py") else file
86
+ importlib.import_module(
87
+ f"fairseq.data.audio.{transform_type}_transforms." + name
88
+ )
89
+
90
+
91
+ # Utility fn for uniform numbers in transforms
92
+ def rand_uniform(a, b):
93
+ return np.random.uniform() * (b - a) + a
fairseq/fairseq/data/audio/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.61 kB).
 
fairseq/fairseq/data/audio/__pycache__/audio_utils.cpython-310.pyc ADDED
Binary file (12.2 kB).
 
fairseq/fairseq/data/audio/__pycache__/data_cfg.cpython-310.pyc ADDED
Binary file (14.6 kB).
 
fairseq/fairseq/data/audio/__pycache__/frm_text_to_speech_dataset.cpython-310.pyc ADDED
Binary file (6.07 kB).
 
fairseq/fairseq/data/audio/__pycache__/hubert_dataset.cpython-310.pyc ADDED
Binary file (12.5 kB).
 
fairseq/fairseq/data/audio/__pycache__/raw_audio_dataset.cpython-310.pyc ADDED
Binary file (11.3 kB).
 
fairseq/fairseq/data/audio/__pycache__/speech_to_speech_dataset.cpython-310.pyc ADDED
Binary file (11.9 kB).
 
fairseq/fairseq/data/audio/__pycache__/speech_to_text_dataset.cpython-310.pyc ADDED
Binary file (25.5 kB).
 
fairseq/fairseq/data/audio/__pycache__/text_to_speech_dataset.cpython-310.pyc ADDED
Binary file (8.36 kB).
 
fairseq/fairseq/data/audio/audio_utils.py ADDED
@@ -0,0 +1,389 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import mmap
8
+ from pathlib import Path
9
+ import io
10
+ from typing import BinaryIO, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform
17
+
18
+ SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"}
19
+ FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
20
+
21
+
22
+ def convert_waveform(
23
+ waveform: Union[np.ndarray, torch.Tensor],
24
+ sample_rate: int,
25
+ normalize_volume: bool = False,
26
+ to_mono: bool = False,
27
+ to_sample_rate: Optional[int] = None,
28
+ ) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
29
+ """convert a waveform:
30
+ - to a target sample rate
31
+ - from multi-channel to mono channel
32
+ - volume normalization
33
+
34
+ Args:
35
+ waveform (numpy.ndarray or torch.Tensor): 2D original waveform
36
+ (channels x length)
37
+ sample_rate (int): original sample rate
38
+ normalize_volume (bool): perform volume normalization
39
+ to_mono (bool): convert to mono channel if having multiple channels
40
+ to_sample_rate (Optional[int]): target sample rate
41
+ Returns:
42
+ waveform (numpy.ndarray): converted 2D waveform (channels x length)
43
+ sample_rate (float): target sample rate
44
+ """
45
+ try:
46
+ import torchaudio.sox_effects as ta_sox
47
+ except ImportError:
48
+ raise ImportError("Please install torchaudio: pip install torchaudio")
49
+
50
+ effects = []
51
+ if normalize_volume:
52
+ effects.append(["gain", "-n"])
53
+ if to_sample_rate is not None and to_sample_rate != sample_rate:
54
+ effects.append(["rate", f"{to_sample_rate}"])
55
+ if to_mono and waveform.shape[0] > 1:
56
+ effects.append(["channels", "1"])
57
+ if len(effects) > 0:
58
+ is_np_input = isinstance(waveform, np.ndarray)
59
+ _waveform = torch.from_numpy(waveform) if is_np_input else waveform
60
+ converted, converted_sample_rate = ta_sox.apply_effects_tensor(
61
+ _waveform, sample_rate, effects
62
+ )
63
+ if is_np_input:
64
+ converted = converted.numpy()
65
+ return converted, converted_sample_rate
66
+ return waveform, sample_rate
67
+
68
+
69
+ def get_waveform(
70
+ path_or_fp: Union[str, BinaryIO],
71
+ normalization: bool = True,
72
+ mono: bool = True,
73
+ frames: int = -1,
74
+ start: int = 0,
75
+ always_2d: bool = True,
76
+ output_sample_rate: Optional[int] = None,
77
+ normalize_volume: bool = False,
78
+ waveform_transforms: Optional[CompositeAudioWaveformTransform] = None,
79
+ ) -> Tuple[np.ndarray, int]:
80
+ """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
81
+
82
+ Args:
83
+ path_or_fp (str or BinaryIO): the path or file-like object
84
+ normalization (bool): normalize values to [-1, 1] (Default: True)
85
+ mono (bool): convert multi-channel audio to mono-channel one
86
+ frames (int): the number of frames to read. (-1 for reading all)
87
+ start (int): Where to start reading. A negative value counts from the end.
88
+ always_2d (bool): always return 2D array even for mono-channel audios
89
+ output_sample_rate (Optional[int]): output sample rate
90
+ normalize_volume (bool): normalize volume
91
+ Returns:
92
+ waveform (numpy.ndarray): 1D or 2D waveform (channels x length)
93
+ sample_rate (float): sample rate
94
+ """
95
+ if isinstance(path_or_fp, str):
96
+ ext = Path(path_or_fp).suffix
97
+ if ext not in SF_AUDIO_FILE_EXTENSIONS:
98
+ raise ValueError(f"Unsupported audio format: {ext}")
99
+
100
+ try:
101
+ import soundfile as sf
102
+ except ImportError:
103
+ raise ImportError("Please install soundfile: pip install soundfile")
104
+
105
+ waveform, sample_rate = sf.read(
106
+ path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start
107
+ )
108
+ waveform = waveform.T # T x C -> C x T
109
+ waveform, sample_rate = convert_waveform(
110
+ waveform,
111
+ sample_rate,
112
+ normalize_volume=normalize_volume,
113
+ to_mono=mono,
114
+ to_sample_rate=output_sample_rate,
115
+ )
116
+
117
+ if not normalization:
118
+ waveform *= 2**15 # denormalized to 16-bit signed integers
119
+
120
+ if waveform_transforms is not None:
121
+ waveform, sample_rate = waveform_transforms(waveform, sample_rate)
122
+
123
+ if not always_2d:
124
+ waveform = waveform.squeeze(axis=0)
125
+
126
+ return waveform, sample_rate
127
+
128
+
129
+ def get_features_from_npy_or_audio(path, waveform_transforms=None):
130
+ ext = Path(path).suffix
131
+ if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
132
+ raise ValueError(f'Unsupported file format for "{path}"')
133
+ return (
134
+ np.load(path)
135
+ if ext == ".npy"
136
+ else get_fbank(path, waveform_transforms=waveform_transforms)
137
+ )
138
+
139
+
140
+ def get_features_or_waveform_from_stored_zip(
141
+ path,
142
+ byte_offset,
143
+ byte_size,
144
+ need_waveform=False,
145
+ use_sample_rate=None,
146
+ waveform_transforms=None,
147
+ ):
148
+ assert path.endswith(".zip")
149
+ data = read_from_stored_zip(path, byte_offset, byte_size)
150
+ f = io.BytesIO(data)
151
+ if is_npy_data(data):
152
+ features_or_waveform = np.load(f)
153
+ elif is_sf_audio_data(data):
154
+ features_or_waveform = (
155
+ get_waveform(
156
+ f,
157
+ always_2d=False,
158
+ output_sample_rate=use_sample_rate,
159
+ waveform_transforms=waveform_transforms,
160
+ )[0]
161
+ if need_waveform
162
+ else get_fbank(f, waveform_transforms=waveform_transforms)
163
+ )
164
+ else:
165
+ raise ValueError(f'Unknown file format for "{path}"')
166
+ return features_or_waveform
167
+
168
+
169
+ def get_features_or_waveform(
170
+ path: str, need_waveform=False, use_sample_rate=None, waveform_transforms=None
171
+ ):
172
+ """Get speech features from .npy file or waveform from .wav/.flac file.
173
+ The file may be inside an uncompressed ZIP file and is accessed via byte
174
+ offset and length.
175
+
176
+ Args:
177
+ path (str): File path in the format of "<.npy/.wav/.flac path>" or
178
+ "<zip path>:<byte offset>:<byte length>".
179
+ need_waveform (bool): return waveform instead of features.
180
+ use_sample_rate (int): change sample rate for the input wave file
181
+
182
+ Returns:
183
+ features_or_waveform (numpy.ndarray): speech features or waveform.
184
+ """
185
+ _path, slice_ptr = parse_path(path)
186
+ if len(slice_ptr) == 0:
187
+ if need_waveform:
188
+ return get_waveform(
189
+ _path,
190
+ always_2d=False,
191
+ output_sample_rate=use_sample_rate,
192
+ waveform_transforms=waveform_transforms,
193
+ )[0]
194
+ return get_features_from_npy_or_audio(
195
+ _path, waveform_transforms=waveform_transforms
196
+ )
197
+ elif len(slice_ptr) == 2:
198
+ features_or_waveform = get_features_or_waveform_from_stored_zip(
199
+ _path,
200
+ slice_ptr[0],
201
+ slice_ptr[1],
202
+ need_waveform=need_waveform,
203
+ use_sample_rate=use_sample_rate,
204
+ waveform_transforms=waveform_transforms,
205
+ )
206
+ else:
207
+ raise ValueError(f"Invalid path: {path}")
208
+
209
+ return features_or_waveform
210
+
211
+
212
+ def _get_kaldi_fbank(
213
+ waveform: np.ndarray, sample_rate: int, n_bins=80
214
+ ) -> Optional[np.ndarray]:
215
+ """Get mel-filter bank features via PyKaldi."""
216
+ try:
217
+ from kaldi.feat.fbank import Fbank, FbankOptions
218
+ from kaldi.feat.mel import MelBanksOptions
219
+ from kaldi.feat.window import FrameExtractionOptions
220
+ from kaldi.matrix import Vector
221
+
222
+ mel_opts = MelBanksOptions()
223
+ mel_opts.num_bins = n_bins
224
+ frame_opts = FrameExtractionOptions()
225
+ frame_opts.samp_freq = sample_rate
226
+ opts = FbankOptions()
227
+ opts.mel_opts = mel_opts
228
+ opts.frame_opts = frame_opts
229
+ fbank = Fbank(opts=opts)
230
+ features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy()
231
+ return features
232
+ except ImportError:
233
+ return None
234
+
235
+
236
+ def _get_torchaudio_fbank(
237
+ waveform: np.ndarray, sample_rate, n_bins=80
238
+ ) -> Optional[np.ndarray]:
239
+ """Get mel-filter bank features via TorchAudio."""
240
+ try:
241
+ import torchaudio.compliance.kaldi as ta_kaldi
242
+
243
+ waveform = torch.from_numpy(waveform)
244
+ features = ta_kaldi.fbank(
245
+ waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
246
+ )
247
+ return features.numpy()
248
+ except ImportError:
249
+ return None
250
+
251
+
252
+ def get_fbank(
253
+ path_or_fp: Union[str, BinaryIO], n_bins=80, waveform_transforms=None
254
+ ) -> np.ndarray:
255
+ """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
256
+ (faster CPP implementation) to TorchAudio (Python implementation). Note that
257
+ Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
258
+ waveform should not be normalized."""
259
+ waveform, sample_rate = get_waveform(
260
+ path_or_fp, normalization=False, waveform_transforms=waveform_transforms
261
+ )
262
+
263
+ features = _get_kaldi_fbank(waveform, sample_rate, n_bins)
264
+ if features is None:
265
+ features = _get_torchaudio_fbank(waveform, sample_rate, n_bins)
266
+ if features is None:
267
+ raise ImportError(
268
+ "Please install pyKaldi or torchaudio to enable "
269
+ "online filterbank feature extraction"
270
+ )
271
+
272
+ return features
273
+
274
+
275
+ def is_npy_data(data: bytes) -> bool:
276
+ return data[0] == 147 and data[1] == 78
277
+
278
+
279
+ def is_sf_audio_data(data: bytes) -> bool:
280
+ is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70
281
+ is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97
282
+ is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103
283
+ return is_wav or is_flac or is_ogg
284
+
285
+
286
+ def mmap_read(path: str, offset: int, length: int) -> bytes:
287
+ with open(path, "rb") as f:
288
+ with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
289
+ data = mmap_o[offset : offset + length]
290
+ return data
291
+
292
+
293
+ def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes:
294
+ return mmap_read(zip_path, offset, length)
295
+
296
+
297
+ def parse_path(path: str) -> Tuple[str, List[int]]:
298
+ """Parse data path which is either a path to
299
+ 1. a .npy/.wav/.flac/.ogg file
300
+ 2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"
301
+
302
+ Args:
303
+ path (str): the data path to parse
304
+
305
+ Returns:
306
+ file_path (str): the file path
307
+ slice_ptr (list of int): empty in case 1;
308
+ byte offset and length for the slice in case 2
309
+ """
310
+
311
+ if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
312
+ _path, slice_ptr = path, []
313
+ else:
314
+ _path, *slice_ptr = path.split(":")
315
+ if not Path(_path).is_file():
316
+ raise FileNotFoundError(f"File not found: {_path}")
317
+ assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}"
318
+ slice_ptr = [int(i) for i in slice_ptr]
319
+ return _path, slice_ptr
320
+
321
+
322
+ def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
323
+ padding = n_fft - win_length
324
+ assert padding >= 0
325
+ return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))
326
+
327
+
328
+ def get_fourier_basis(n_fft: int) -> torch.Tensor:
329
+ basis = np.fft.fft(np.eye(n_fft))
330
+ basis = np.vstack(
331
+ [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
332
+ )
333
+ return torch.from_numpy(basis).float()
334
+
335
+
336
+ def get_mel_filters(
337
+ sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
338
+ ) -> torch.Tensor:
339
+ try:
340
+ import librosa
341
+ except ImportError:
342
+ raise ImportError("Please install librosa: pip install librosa")
343
+ basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max)
344
+ return torch.from_numpy(basis).float()
345
+
346
+
347
+ class TTSSpectrogram(torch.nn.Module):
348
+ def __init__(
349
+ self,
350
+ n_fft: int,
351
+ win_length: int,
352
+ hop_length: int,
353
+ window_fn: callable = torch.hann_window,
354
+ return_phase: bool = False,
355
+ ) -> None:
356
+ super(TTSSpectrogram, self).__init__()
357
+ self.n_fft = n_fft
358
+ self.hop_length = hop_length
359
+ self.return_phase = return_phase
360
+
361
+ basis = get_fourier_basis(n_fft).unsqueeze(1)
362
+ basis *= get_window(window_fn, n_fft, win_length)
363
+ self.register_buffer("basis", basis)
364
+
365
+ def forward(
366
+ self, waveform: torch.Tensor
367
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
368
+ padding = (self.n_fft // 2, self.n_fft // 2)
369
+ x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
370
+ x = F.conv1d(x, self.basis, stride=self.hop_length)
371
+ real_part = x[:, : self.n_fft // 2 + 1, :]
372
+ imag_part = x[:, self.n_fft // 2 + 1 :, :]
373
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
374
+ if self.return_phase:
375
+ phase = torch.atan2(imag_part, real_part)
376
+ return magnitude, phase
377
+ return magnitude
378
+
379
+
380
+ class TTSMelScale(torch.nn.Module):
381
+ def __init__(
382
+ self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
383
+ ) -> None:
384
+ super(TTSMelScale, self).__init__()
385
+ basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
386
+ self.register_buffer("basis", basis)
387
+
388
+ def forward(self, specgram: torch.Tensor) -> torch.Tensor:
389
+ return torch.matmul(self.basis, specgram)
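
Usage sketch (not part of this commit): get_features_or_waveform accepts either a plain .npy/.wav/.flac/.ogg path or a ZIP slice in the "<zip path>:<byte offset>:<byte length>" form handled by parse_path above. The file names below are placeholders, assuming this fairseq checkout is importable.

    from fairseq.data.audio.audio_utils import get_waveform, get_features_or_waveform

    # Plain audio file: returns a float32 waveform and its sample rate.
    wav, sr = get_waveform("utt1.flac", always_2d=False, output_sample_rate=16000)

    # Slice of an uncompressed ZIP: 4096 bytes starting at byte offset 1024.
    feats = get_features_or_waveform("fbank80.zip:1024:4096", need_waveform=False)
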
fairseq/fairseq/data/audio/data_cfg.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from argparse import Namespace
8
+ from copy import deepcopy
9
+ from pathlib import Path
10
+ from typing import Dict, Optional
11
+
12
+ from fairseq.data import Dictionary
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def get_config_from_yaml(yaml_path: Path):
18
+ try:
19
+ import yaml
20
+ except ImportError:
21
+ print("Please install PyYAML: pip install PyYAML")
22
+ config = {}
23
+ if yaml_path.is_file():
24
+ try:
25
+ with open(yaml_path) as f:
26
+ config = yaml.load(f, Loader=yaml.FullLoader)
27
+ except Exception as e:
28
+ raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
29
+ else:
30
+ raise FileNotFoundError(f"{yaml_path.as_posix()} not found")
31
+
32
+ return config
33
+
34
+
35
+ class S2TDataConfig(object):
36
+ """Wrapper class for data config YAML"""
37
+
38
+ def __init__(self, yaml_path: Path):
39
+ self.config = get_config_from_yaml(yaml_path)
40
+ self.root = yaml_path.parent
41
+
42
+ def _auto_convert_to_abs_path(self, x):
43
+ if isinstance(x, str):
44
+ if not Path(x).exists() and (self.root / x).exists():
45
+ return (self.root / x).as_posix()
46
+ elif isinstance(x, dict):
47
+ return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
48
+ return x
49
+
50
+ @property
51
+ def vocab_filename(self):
52
+ """fairseq vocabulary file under data root"""
53
+ return self.config.get("vocab_filename", "dict.txt")
54
+
55
+ @property
56
+ def speaker_set_filename(self):
57
+ """speaker set file under data root"""
58
+ return self.config.get("speaker_set_filename", None)
59
+
60
+ @property
61
+ def shuffle(self) -> bool:
62
+ """Shuffle dataset samples before batching"""
63
+ return self.config.get("shuffle", False)
64
+
65
+ @property
66
+ def pre_tokenizer(self) -> Dict:
67
+ """Pre-tokenizer to apply before subword tokenization. Returning
68
+ a dictionary with `tokenizer` providing the tokenizer name and
69
+ the other items providing the tokenizer-specific arguments.
70
+ Tokenizers are defined in `fairseq.data.encoders.*`"""
71
+ tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
72
+ return self._auto_convert_to_abs_path(tokenizer)
73
+
74
+ @property
75
+ def bpe_tokenizer(self) -> Dict:
76
+ """Subword tokenizer to apply after pre-tokenization. Returning
77
+ a dictionary with `bpe` providing the tokenizer name and
78
+ the other items providing the tokenizer-specific arguments.
79
+ Tokenizers are defined in `fairseq.data.encoders.*`"""
80
+ tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
81
+ return self._auto_convert_to_abs_path(tokenizer)
82
+
83
+ @property
84
+ def prepend_tgt_lang_tag(self) -> bool:
85
+ """Prepend target lang ID token as the target BOS (e.g. for to-many
86
+ multilingual setting). During inference, this requires `--prefix-size 1`
87
+ to force BOS to be lang ID token."""
88
+ return self.config.get("prepend_tgt_lang_tag", False)
89
+
90
+ @property
91
+ def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
92
+ """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
93
+ return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
94
+
95
+ @property
96
+ def input_feat_per_channel(self):
97
+ """The dimension of input features (per audio channel)"""
98
+ return self.config.get("input_feat_per_channel", 80)
99
+
100
+ @property
101
+ def input_channels(self):
102
+ """The number of channels in the input audio"""
103
+ return self.config.get("input_channels", 1)
104
+
105
+ @property
106
+ def sample_rate(self):
107
+ return self.config.get("sample_rate", 16_000)
108
+
109
+ @property
110
+ def sampling_alpha(self):
111
+ """Hyper-parameter alpha = 1/T for temperature-based resampling.
112
+ (alpha = 1 for no resampling)"""
113
+ return self.config.get("sampling_alpha", 1.0)
114
+
115
+ @property
116
+ def use_audio_input(self):
117
+ """Needed by the dataset loader to see if the model requires
118
+ raw audio as inputs."""
119
+ return self.config.get("use_audio_input", False)
120
+
121
+ def standardize_audio(self) -> bool:
122
+ return self.use_audio_input and self.config.get("standardize_audio", False)
123
+
124
+ @property
125
+ def use_sample_rate(self):
126
+ """Needed by the dataset loader to see if the model requires
127
+ raw audio with specific sample rate as inputs."""
128
+ return self.config.get("use_sample_rate", 16000)
129
+
130
+ @property
131
+ def audio_root(self):
132
+ """Audio paths in the manifest TSV can be relative and this provides
133
+ the root path. Set this to empty string when using absolute paths."""
134
+ return self.config.get("audio_root", "")
135
+
136
+ def get_transforms(self, transform_type, split, is_train):
137
+ """Split-specific feature transforms. Allowing train set
138
+ wildcard `_train`, evaluation set wildcard `_eval` and general
139
+ wildcard `*` for matching."""
140
+ from copy import deepcopy
141
+
142
+ cfg = deepcopy(self.config)
143
+ _cur = cfg.get(f"{transform_type}transforms", {})
144
+ cur = _cur.get(split)
145
+ cur = _cur.get("_train") if cur is None and is_train else cur
146
+ cur = _cur.get("_eval") if cur is None and not is_train else cur
147
+ cur = _cur.get("*") if cur is None else cur
148
+ return cur
149
+
150
+ def get_feature_transforms(self, split, is_train):
151
+ cfg = deepcopy(self.config)
152
+ # TODO: deprecate transforms
153
+ cur = self.get_transforms("", split, is_train)
154
+ if cur is not None:
155
+ logger.warning(
156
+ "Auto converting transforms into feature_transforms, "
157
+ "but transforms will be deprecated in the future. Please "
158
+ "update this in the config."
159
+ )
160
+ ft_transforms = self.get_transforms("feature_", split, is_train)
161
+ if ft_transforms:
162
+ cur.extend(ft_transforms)
163
+ else:
164
+ cur = self.get_transforms("feature_", split, is_train)
165
+ cfg["feature_transforms"] = cur
166
+ return cfg
167
+
168
+ def get_waveform_transforms(self, split, is_train):
169
+ cfg = deepcopy(self.config)
170
+ cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
171
+ return cfg
172
+
173
+ def get_dataset_transforms(self, split, is_train):
174
+ cfg = deepcopy(self.config)
175
+ cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
176
+ return cfg
177
+
178
+ @property
179
+ def global_cmvn_stats_npz(self) -> Optional[str]:
180
+ path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
181
+ return self._auto_convert_to_abs_path(path)
182
+
183
+ @property
184
+ def vocoder(self) -> Dict[str, str]:
185
+ vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
186
+ return self._auto_convert_to_abs_path(vocoder)
187
+
188
+ @property
189
+ def hub(self) -> Dict[str, str]:
190
+ return self.config.get("hub", {})
191
+
192
+
193
+ class S2SDataConfig(S2TDataConfig):
194
+ """Wrapper class for data config YAML"""
195
+
196
+ @property
197
+ def vocab_filename(self):
198
+ """fairseq vocabulary file under data root"""
199
+ return self.config.get("vocab_filename", None)
200
+
201
+ @property
202
+ def pre_tokenizer(self) -> Dict:
203
+ return None
204
+
205
+ @property
206
+ def bpe_tokenizer(self) -> Dict:
207
+ return None
208
+
209
+ @property
210
+ def input_transformed_channels(self):
211
+ """The number of channels in the audio after feature transforms"""
212
+ # TODO: move this into individual transforms
213
+ # TODO: deprecate transforms
214
+ _cur = self.config.get("transforms", {})
215
+ ft_transforms = self.config.get("feature_transforms", {})
216
+ if _cur and ft_transforms:
217
+ _cur.update(ft_transforms)
218
+ else:
219
+ _cur = self.config.get("feature_transforms", {})
220
+ cur = _cur.get("_train", [])
221
+
222
+ _channels = self.input_channels
223
+ if "delta_deltas" in cur:
224
+ _channels *= 3
225
+
226
+ return _channels
227
+
228
+ @property
229
+ def output_sample_rate(self):
230
+ """The audio sample rate of output target speech"""
231
+ return self.config.get("output_sample_rate", 22050)
232
+
233
+ @property
234
+ def target_speaker_embed(self):
235
+ """Target speaker embedding file (one line per target audio sample)"""
236
+ return self.config.get("target_speaker_embed", None)
237
+
238
+ @property
239
+ def prepend_tgt_lang_tag_as_bos(self) -> bool:
240
+ """Prepend target lang ID token as the target BOS."""
241
+ return self.config.get("prepend_tgt_lang_tag_as_bos", False)
242
+
243
+
244
+ class MultitaskConfig(object):
245
+ """Wrapper class for data config YAML"""
246
+
247
+ def __init__(self, yaml_path: Path):
248
+ config = get_config_from_yaml(yaml_path)
249
+ self.config = {}
250
+ for k, v in config.items():
251
+ self.config[k] = SingleTaskConfig(k, v)
252
+
253
+ def get_all_tasks(self):
254
+ return self.config
255
+
256
+ def get_single_task(self, name):
257
+ assert name in self.config, f"multitask '{name}' does not exist!"
258
+ return self.config[name]
259
+
260
+ @property
261
+ def first_pass_decoder_task_index(self):
262
+ """Return the task index of the first-pass text decoder.
263
+ If there are multiple 'is_first_pass_decoder: True' in the config file,
264
+ the last task is used for the first-pass decoder.
265
+ If there is no 'is_first_pass_decoder: True' in the config file,
266
+ the last task whose task_name includes 'target' and decoder_type is not ctc.
267
+ """
268
+ idx = -1
269
+ for i, (k, v) in enumerate(self.config.items()):
270
+ if v.is_first_pass_decoder:
271
+ idx = i
272
+ if idx < 0:
273
+ for i, (k, v) in enumerate(self.config.items()):
274
+ if k.startswith("target") and v.decoder_type == "transformer":
275
+ idx = i
276
+ return idx
277
+
278
+
279
+ class SingleTaskConfig(object):
280
+ def __init__(self, name, config):
281
+ self.task_name = name
282
+ self.config = config
283
+ dict_path = config.get("dict", "")
284
+ self.tgt_dict = Dictionary.load(dict_path) if Path(dict_path).exists() else None
285
+
286
+ @property
287
+ def data(self):
288
+ return self.config.get("data", "")
289
+
290
+ @property
291
+ def decoder_type(self):
292
+ return self.config.get("decoder_type", "transformer")
293
+
294
+ @property
295
+ def decoder_args(self):
296
+ """Decoder arch related args"""
297
+ args = self.config.get("decoder_args", {})
298
+ return Namespace(**args)
299
+
300
+ @property
301
+ def criterion_cfg(self):
302
+ """cfg for the multitask criterion"""
303
+ if self.decoder_type == "ctc":
304
+ from fairseq.criterions.ctc import CtcCriterionConfig
305
+
306
+ cfg = CtcCriterionConfig
307
+ cfg.zero_infinity = self.config.get("zero_infinity", True)
308
+ else:
309
+ from fairseq.criterions.label_smoothed_cross_entropy import (
310
+ LabelSmoothedCrossEntropyCriterionConfig,
311
+ )
312
+
313
+ cfg = LabelSmoothedCrossEntropyCriterionConfig
314
+ cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
315
+ return cfg
316
+
317
+ @property
318
+ def input_from(self):
319
+ """Condition on encoder/decoder of the main model"""
320
+ return "decoder" if "decoder_layer" in self.config else "encoder"
321
+
322
+ @property
323
+ def input_layer(self):
324
+ if self.input_from == "decoder":
325
+ return self.config["decoder_layer"] - 1
326
+ else:
327
+ # default using the output from the last encoder layer (-1)
328
+ return self.config.get("encoder_layer", 0) - 1
329
+
330
+ @property
331
+ def loss_weight_schedule(self):
332
+ return (
333
+ "decay"
334
+ if "loss_weight_max" in self.config
335
+ and "loss_weight_decay_steps" in self.config
336
+ else "fixed"
337
+ )
338
+
339
+ def get_loss_weight(self, num_updates):
340
+ if self.loss_weight_schedule == "fixed":
341
+ weight = self.config.get("loss_weight", 1.0)
342
+ else: # "decay"
343
+ assert (
344
+ self.config.get("loss_weight_decay_steps", 0) > 0
345
+ ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
346
+ loss_weight_min = self.config.get("loss_weight_min", 0.0001)
347
+ loss_weight_decay_stepsize = (
348
+ self.config["loss_weight_max"] - loss_weight_min
349
+ ) / self.config["loss_weight_decay_steps"]
350
+ weight = max(
351
+ self.config["loss_weight_max"]
352
+ - loss_weight_decay_stepsize * num_updates,
353
+ loss_weight_min,
354
+ )
355
+ return weight
356
+
357
+ @property
358
+ def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
359
+ """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
360
+ return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
361
+
362
+ @property
363
+ def eos_token(self):
364
+ """EOS token during generation"""
365
+ return self.config.get("eos_token", "<eos>")
366
+
367
+ @property
368
+ def rdrop_alpha(self):
369
+ return self.config.get("rdrop_alpha", 0.0)
370
+
371
+ @property
372
+ def is_first_pass_decoder(self):
373
+ flag = self.config.get("is_first_pass_decoder", False)
374
+ if flag:
375
+ if self.decoder_type == "ctc":
376
+ raise ValueError(
377
+ "First-pass decoder in the multi-decoder model must not be CTC."
378
+ )
379
+ if "target" not in self.task_name:
380
+ raise Warning(
381
+ 'The name of the first-pass decoder does not include "target".'
382
+ )
383
+ return flag
384
+
385
+ @property
386
+ def get_lang_tag_mapping(self):
387
+ return self.config.get("lang_tag_mapping", {})
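
Config sketch (not part of this commit): S2TDataConfig reads its settings from a YAML file next to the data; the keys below mirror the properties defined above, with placeholder values, and are expressed as a Python dict only for illustration.

    example_s2t_config = {
        "vocab_filename": "dict.txt",          # default vocabulary file name
        "audio_root": "",                      # prefix for relative audio paths in the manifest
        "sample_rate": 16000,
        "use_audio_input": False,              # False: features rather than raw audio
        "pre_tokenizer": {"tokenizer": None},  # see fairseq.data.encoders.*
        "bpe_tokenizer": {"bpe": None},
        # split-specific transforms resolved by get_transforms(): "_train", "_eval", "*"
        "feature_transforms": {"_train": ["utterance_cmvn", "specaugment"], "*": ["utterance_cmvn"]},
    }
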
fairseq/fairseq/data/audio/dataset_transforms/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+ from fairseq.data.audio import (
3
+ AudioTransform,
4
+ CompositeAudioTransform,
5
+ import_transforms,
6
+ register_audio_transform,
7
+ )
8
+
9
+
10
+ class AudioDatasetTransform(AudioTransform):
11
+ pass
12
+
13
+
14
+ AUDIO_DATASET_TRANSFORM_REGISTRY = {}
15
+ AUDIO_DATASET_TRANSFORM_CLASS_NAMES = set()
16
+
17
+
18
+ def get_audio_dataset_transform(name):
19
+ return AUDIO_DATASET_TRANSFORM_REGISTRY[name]
20
+
21
+
22
+ def register_audio_dataset_transform(name):
23
+ return register_audio_transform(
24
+ name,
25
+ AudioDatasetTransform,
26
+ AUDIO_DATASET_TRANSFORM_REGISTRY,
27
+ AUDIO_DATASET_TRANSFORM_CLASS_NAMES,
28
+ )
29
+
30
+
31
+ import_transforms(os.path.dirname(__file__), "dataset")
32
+
33
+
34
+ class CompositeAudioDatasetTransform(CompositeAudioTransform):
35
+ @classmethod
36
+ def from_config_dict(cls, config=None):
37
+ return super()._from_config_dict(
38
+ cls,
39
+ "dataset",
40
+ get_audio_dataset_transform,
41
+ CompositeAudioDatasetTransform,
42
+ config,
43
+ return_empty=True,
44
+ )
45
+
46
+ def get_transform(self, cls):
47
+ for t in self.transforms:
48
+ if isinstance(t, cls):
49
+ return t
50
+ return None
51
+
52
+ def has_transform(self, cls):
53
+ return self.get_transform(cls) is not None
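
Registry sketch (hypothetical transform, not part of this commit): a dataset transform registers itself by name with the decorator added above and is later instantiated through CompositeAudioDatasetTransform.from_config_dict.

    from fairseq.data.audio.dataset_transforms import (
        AudioDatasetTransform,
        register_audio_dataset_transform,
    )

    @register_audio_dataset_transform("identity")   # hypothetical name
    class IdentityDatasetTransform(AudioDatasetTransform):
        @classmethod
        def from_config_dict(cls, config=None):
            return cls()

        def __call__(self, x):
            return x                                # no-op transform

    # CompositeAudioDatasetTransform.from_config_dict({"dataset_transforms": ["identity"]})
    # would then yield a composite containing one IdentityDatasetTransform.
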
fairseq/fairseq/data/audio/dataset_transforms/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.87 kB).
 
fairseq/fairseq/data/audio/dataset_transforms/__pycache__/noisyoverlapaugment.cpython-310.pyc ADDED
Binary file (3.02 kB).
 
fairseq/fairseq/data/audio/dataset_transforms/concataugment.py ADDED
@@ -0,0 +1,61 @@
1
+ from typing import List
2
+ import numpy as np
3
+
4
+ from fairseq.data.audio.dataset_transforms import (
5
+ AudioDatasetTransform,
6
+ register_audio_dataset_transform,
7
+ )
8
+
9
+ _DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}
10
+
11
+
12
+ @register_audio_dataset_transform("concataugment")
13
+ class ConcatAugment(AudioDatasetTransform):
14
+ @classmethod
15
+ def from_config_dict(cls, config=None):
16
+ _config = {} if config is None else config
17
+ return ConcatAugment(
18
+ _config.get("rate", _DEFAULTS["rate"]),
19
+ _config.get("max_tokens", _DEFAULTS["max_tokens"]),
20
+ _config.get("attempts", _DEFAULTS["attempts"]),
21
+ )
22
+
23
+ def __init__(
24
+ self,
25
+ rate=_DEFAULTS["rate"],
26
+ max_tokens=_DEFAULTS["max_tokens"],
27
+ attempts=_DEFAULTS["attempts"],
28
+ ):
29
+ self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts
30
+
31
+ def __repr__(self):
32
+ return (
33
+ self.__class__.__name__
34
+ + "("
35
+ + ", ".join(
36
+ [
37
+ f"rate={self.rate}",
38
+ f"max_tokens={self.max_tokens}",
39
+ f"attempts={self.attempts}",
40
+ ]
41
+ )
42
+ + ")"
43
+ )
44
+
45
+ def find_indices(self, index: int, n_frames: List[int], n_samples: int):
46
+ # skip conditions: application rate, max_tokens limit exceeded
47
+ if np.random.random() > self.rate:
48
+ return [index]
49
+ if self.max_tokens and n_frames[index] > self.max_tokens:
50
+ return [index]
51
+
52
+ # pick second sample to concatenate
53
+ for _ in range(self.attempts):
54
+ index2 = np.random.randint(0, n_samples)
55
+ if index2 != index and (
56
+ not self.max_tokens
57
+ or n_frames[index] + n_frames[index2] < self.max_tokens
58
+ ):
59
+ return [index, index2]
60
+
61
+ return [index]
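
Behaviour sketch (not part of this commit): find_indices returns just [index] unless the augmentation fires (probability rate) and a second sample fits under max_tokens; the consuming dataset is assumed to concatenate the returned items. The frame counts below are placeholders.

    from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment

    aug = ConcatAugment(rate=0.5, max_tokens=3000, attempts=5)
    n_frames = [1200, 900, 2500, 400]                     # per-sample frame counts
    print(aug.find_indices(0, n_frames, len(n_frames)))   # e.g. [0] or [0, 3]
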
fairseq/fairseq/data/audio/dataset_transforms/noisyoverlapaugment.py ADDED
@@ -0,0 +1,105 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from fairseq.data.audio import rand_uniform
5
+ from fairseq.data.audio.dataset_transforms import (
6
+ AudioDatasetTransform,
7
+ register_audio_dataset_transform,
8
+ )
9
+ from fairseq.data.audio.waveform_transforms.noiseaugment import (
10
+ NoiseAugmentTransform,
11
+ )
12
+
13
+ _DEFAULTS = {
14
+ "rate": 0.25,
15
+ "mixing_noise_rate": 0.1,
16
+ "noise_path": "",
17
+ "noise_snr_min": -5,
18
+ "noise_snr_max": 5,
19
+ "utterance_snr_min": -5,
20
+ "utterance_snr_max": 5,
21
+ }
22
+
23
+
24
+ @register_audio_dataset_transform("noisyoverlapaugment")
25
+ class NoisyOverlapAugment(AudioDatasetTransform):
26
+ @classmethod
27
+ def from_config_dict(cls, config=None):
28
+ _config = {} if config is None else config
29
+ return NoisyOverlapAugment(
30
+ _config.get("rate", _DEFAULTS["rate"]),
31
+ _config.get("mixing_noise_rate", _DEFAULTS["mixing_noise_rate"]),
32
+ _config.get("noise_path", _DEFAULTS["noise_path"]),
33
+ _config.get("noise_snr_min", _DEFAULTS["noise_snr_min"]),
34
+ _config.get("noise_snr_max", _DEFAULTS["noise_snr_max"]),
35
+ _config.get("utterance_snr_min", _DEFAULTS["utterance_snr_min"]),
36
+ _config.get("utterance_snr_max", _DEFAULTS["utterance_snr_max"]),
37
+ )
38
+
39
+ def __init__(
40
+ self,
41
+ rate=_DEFAULTS["rate"],
42
+ mixing_noise_rate=_DEFAULTS["mixing_noise_rate"],
43
+ noise_path=_DEFAULTS["noise_path"],
44
+ noise_snr_min=_DEFAULTS["noise_snr_min"],
45
+ noise_snr_max=_DEFAULTS["noise_snr_max"],
46
+ utterance_snr_min=_DEFAULTS["utterance_snr_min"],
47
+ utterance_snr_max=_DEFAULTS["utterance_snr_max"],
48
+ ):
49
+ self.rate = rate
50
+ self.mixing_noise_rate = mixing_noise_rate
51
+ self.noise_shaper = NoiseAugmentTransform(noise_path)
52
+ self.noise_snr_min = noise_snr_min
53
+ self.noise_snr_max = noise_snr_max
54
+ self.utterance_snr_min = utterance_snr_min
55
+ self.utterance_snr_max = utterance_snr_max
56
+
57
+ def __repr__(self):
58
+ return (
59
+ self.__class__.__name__
60
+ + "("
61
+ + ", ".join(
62
+ [
63
+ f"rate={self.rate}",
64
+ f"mixing_noise_rate={self.mixing_noise_rate}",
65
+ f"noise_snr_min={self.noise_snr_min}",
66
+ f"noise_snr_max={self.noise_snr_max}",
67
+ f"utterance_snr_min={self.utterance_snr_min}",
68
+ f"utterance_snr_max={self.utterance_snr_max}",
69
+ ]
70
+ )
71
+ + ")"
72
+ )
73
+
74
+ def __call__(self, sources):
75
+ for i, source in enumerate(sources):
76
+ if np.random.random() > self.rate:
77
+ continue
78
+
79
+ pri = source.numpy()
80
+
81
+ if np.random.random() > self.mixing_noise_rate:
82
+ sec = sources[np.random.randint(0, len(sources))].numpy()
83
+ snr = rand_uniform(self.utterance_snr_min, self.utterance_snr_max)
84
+ else:
85
+ sec = self.noise_shaper.pick_sample(source.shape)
86
+ snr = rand_uniform(self.noise_snr_min, self.noise_snr_max)
87
+
88
+ L1 = pri.shape[-1]
89
+ L2 = sec.shape[-1]
90
+ l = np.random.randint(0, min(round(L1 / 2), L2)) # mix len
91
+ s_source = np.random.randint(0, L1 - l)
92
+ s_sec = np.random.randint(0, L2 - l)
93
+
94
+ get_power = lambda x: np.mean(x**2)
95
+ if get_power(sec) == 0:
96
+ continue
97
+
98
+ scl = np.sqrt(get_power(pri) / (np.power(10, snr / 10) * get_power(sec)))
99
+
100
+ pri[s_source : s_source + l] = np.add(
101
+ pri[s_source : s_source + l], np.multiply(scl, sec[s_sec : s_sec + l])
102
+ )
103
+ sources[i] = torch.from_numpy(pri).float()
104
+
105
+ return sources
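
Numeric check (not part of this commit) of the SNR scaling above: multiplying the secondary signal by sqrt(P_pri / (10**(snr/10) * P_sec)) makes the primary-to-secondary power ratio equal the sampled SNR in dB.

    import numpy as np

    pri, sec, snr = np.random.randn(16000), np.random.randn(16000), 5.0
    power = lambda x: np.mean(x ** 2)
    scl = np.sqrt(power(pri) / (np.power(10, snr / 10) * power(sec)))
    print(10 * np.log10(power(pri) / power(scl * sec)))   # ~= 5.0 dB
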
fairseq/fairseq/data/audio/feature_transforms/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ import os
2
+ from fairseq.data.audio import (
3
+ AudioTransform,
4
+ CompositeAudioTransform,
5
+ import_transforms,
6
+ register_audio_transform,
7
+ )
8
+
9
+
10
+ class AudioFeatureTransform(AudioTransform):
11
+ pass
12
+
13
+
14
+ AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
15
+ AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
16
+
17
+
18
+ def get_audio_feature_transform(name):
19
+ return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
20
+
21
+
22
+ def register_audio_feature_transform(name):
23
+ return register_audio_transform(
24
+ name,
25
+ AudioFeatureTransform,
26
+ AUDIO_FEATURE_TRANSFORM_REGISTRY,
27
+ AUDIO_FEATURE_TRANSFORM_CLASS_NAMES,
28
+ )
29
+
30
+
31
+ import_transforms(os.path.dirname(__file__), "feature")
32
+
33
+
34
+ class CompositeAudioFeatureTransform(CompositeAudioTransform):
35
+ @classmethod
36
+ def from_config_dict(cls, config=None):
37
+ return super()._from_config_dict(
38
+ cls,
39
+ "feature",
40
+ get_audio_feature_transform,
41
+ CompositeAudioFeatureTransform,
42
+ config,
43
+ )
fairseq/fairseq/data/audio/feature_transforms/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.46 kB).
 
fairseq/fairseq/data/audio/feature_transforms/__pycache__/global_cmvn.cpython-310.pyc ADDED
Binary file (1.53 kB).
 
fairseq/fairseq/data/audio/feature_transforms/__pycache__/specaugment.cpython-310.pyc ADDED
Binary file (3.37 kB).
 
fairseq/fairseq/data/audio/feature_transforms/__pycache__/utterance_cmvn.cpython-310.pyc ADDED
Binary file (1.65 kB).
 
fairseq/fairseq/data/audio/feature_transforms/delta_deltas.py ADDED
@@ -0,0 +1,37 @@
1
+ import numpy as np
2
+ import torch
3
+ from fairseq.data.audio.feature_transforms import (
4
+ AudioFeatureTransform,
5
+ register_audio_feature_transform,
6
+ )
7
+
8
+
9
+ @register_audio_feature_transform("delta_deltas")
10
+ class DeltaDeltas(AudioFeatureTransform):
11
+ """Expand delta-deltas features from spectrum."""
12
+
13
+ @classmethod
14
+ def from_config_dict(cls, config=None):
15
+ _config = {} if config is None else config
16
+ return DeltaDeltas(_config.get("win_length", 5))
17
+
18
+ def __init__(self, win_length=5):
19
+ self.win_length = win_length
20
+
21
+ def __repr__(self):
22
+ return self.__class__.__name__
23
+
24
+ def __call__(self, spectrogram):
25
+ from torchaudio.functional import compute_deltas
26
+
27
+ assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
28
+ # spectrogram is T x F, while compute_deltas takes (…, F, T)
29
+ spectrogram = torch.from_numpy(spectrogram).transpose(0, 1)
30
+ delta = compute_deltas(spectrogram)
31
+ delta_delta = compute_deltas(delta)
32
+
33
+ out_feat = np.concatenate(
34
+ [spectrogram, delta.numpy(), delta_delta.numpy()], axis=0
35
+ )
36
+ out_feat = np.transpose(out_feat)
37
+ return out_feat
fairseq/fairseq/data/audio/feature_transforms/specaugment.py ADDED
@@ -0,0 +1,131 @@
1
+ import math
2
+ import numbers
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ from fairseq.data.audio.feature_transforms import (
7
+ AudioFeatureTransform,
8
+ register_audio_feature_transform,
9
+ )
10
+
11
+
12
+ @register_audio_feature_transform("specaugment")
13
+ class SpecAugmentTransform(AudioFeatureTransform):
14
+ """SpecAugment (https://arxiv.org/abs/1904.08779)"""
15
+
16
+ @classmethod
17
+ def from_config_dict(cls, config=None):
18
+ _config = {} if config is None else config
19
+ return SpecAugmentTransform(
20
+ _config.get("time_warp_W", 0),
21
+ _config.get("freq_mask_N", 0),
22
+ _config.get("freq_mask_F", 0),
23
+ _config.get("time_mask_N", 0),
24
+ _config.get("time_mask_T", 0),
25
+ _config.get("time_mask_p", 0.0),
26
+ _config.get("mask_value", None),
27
+ )
28
+
29
+ def __init__(
30
+ self,
31
+ time_warp_w: int = 0,
32
+ freq_mask_n: int = 0,
33
+ freq_mask_f: int = 0,
34
+ time_mask_n: int = 0,
35
+ time_mask_t: int = 0,
36
+ time_mask_p: float = 0.0,
37
+ mask_value: Optional[float] = 0.0,
38
+ ):
39
+ # Sanity checks
40
+ assert mask_value is None or isinstance(
41
+ mask_value, numbers.Number
42
+ ), f"mask_value (type: {type(mask_value)}) must be None or a number"
43
+ if freq_mask_n > 0:
44
+ assert freq_mask_f > 0, (
45
+ f"freq_mask_F ({freq_mask_f}) "
46
+ f"must be larger than 0 when doing freq masking."
47
+ )
48
+ if time_mask_n > 0:
49
+ assert time_mask_t > 0, (
50
+ f"time_mask_T ({time_mask_t}) must be larger than 0 when "
51
+ f"doing time masking."
52
+ )
53
+
54
+ self.time_warp_w = time_warp_w
55
+ self.freq_mask_n = freq_mask_n
56
+ self.freq_mask_f = freq_mask_f
57
+ self.time_mask_n = time_mask_n
58
+ self.time_mask_t = time_mask_t
59
+ self.time_mask_p = time_mask_p
60
+ self.mask_value = mask_value
61
+
62
+ def __repr__(self):
63
+ return (
64
+ self.__class__.__name__
65
+ + "("
66
+ + ", ".join(
67
+ [
68
+ f"time_warp_w={self.time_warp_w}",
69
+ f"freq_mask_n={self.freq_mask_n}",
70
+ f"freq_mask_f={self.freq_mask_f}",
71
+ f"time_mask_n={self.time_mask_n}",
72
+ f"time_mask_t={self.time_mask_t}",
73
+ f"time_mask_p={self.time_mask_p}",
74
+ ]
75
+ )
76
+ + ")"
77
+ )
78
+
79
+ def __call__(self, spectrogram):
80
+ assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
81
+
82
+ distorted = spectrogram.copy() # make a copy of input spectrogram.
83
+ num_frames = spectrogram.shape[0] # or 'tau' in the paper.
84
+ num_freqs = spectrogram.shape[1] # or 'miu' in the paper.
85
+ mask_value = self.mask_value
86
+
87
+ if mask_value is None: # if no value was specified, use local mean.
88
+ mask_value = spectrogram.mean()
89
+
90
+ if num_frames == 0:
91
+ return spectrogram
92
+
93
+ if num_freqs < self.freq_mask_f:
94
+ return spectrogram
95
+
96
+ if self.time_warp_w > 0:
97
+ if 2 * self.time_warp_w < num_frames:
98
+ import cv2
99
+
100
+ w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
101
+ w = np.random.randint(-self.time_warp_w + 1, self.time_warp_w)
102
+ upper, lower = distorted[:w0, :], distorted[w0:, :]
103
+ upper = cv2.resize(
104
+ upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
105
+ )
106
+ lower = cv2.resize(
107
+ lower,
108
+ dsize=(num_freqs, num_frames - w0 - w),
109
+ interpolation=cv2.INTER_LINEAR,
110
+ )
111
+ distorted = np.concatenate((upper, lower), axis=0)
112
+
113
+ for _i in range(self.freq_mask_n):
114
+ f = np.random.randint(0, self.freq_mask_f)
115
+ f0 = np.random.randint(0, num_freqs - f)
116
+ if f != 0:
117
+ distorted[:, f0 : f0 + f] = mask_value
118
+
119
+ max_time_mask_t = min(
120
+ self.time_mask_t, math.floor(num_frames * self.time_mask_p)
121
+ )
122
+ if max_time_mask_t < 1:
123
+ return distorted
124
+
125
+ for _i in range(self.time_mask_n):
126
+ t = np.random.randint(0, max_time_mask_t)
127
+ t0 = np.random.randint(0, num_frames - t)
128
+ if t != 0:
129
+ distorted[t0 : t0 + t, :] = mask_value
130
+
131
+ return distorted
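
Instantiation sketch (not part of this commit): the config keys match those read by from_config_dict above; the numbers are placeholders, not recommended settings.

    import numpy as np
    from fairseq.data.audio.feature_transforms.specaugment import SpecAugmentTransform

    transform = SpecAugmentTransform.from_config_dict(
        {
            "time_warp_W": 0,    # disable time warping (avoids the cv2 dependency)
            "freq_mask_N": 2, "freq_mask_F": 27,
            "time_mask_N": 2, "time_mask_T": 40, "time_mask_p": 1.0,
        }
    )
    masked = transform(np.random.randn(500, 80).astype(np.float32))  # T x F spectrogram
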
fairseq/fairseq/data/audio/feature_transforms/utterance_cmvn.py ADDED
@@ -0,0 +1,41 @@
1
+ import numpy as np
2
+
3
+ from fairseq.data.audio.feature_transforms import (
4
+ AudioFeatureTransform,
5
+ register_audio_feature_transform,
6
+ )
7
+
8
+
9
+ @register_audio_feature_transform("utterance_cmvn")
10
+ class UtteranceCMVN(AudioFeatureTransform):
11
+ """Utterance-level CMVN (cepstral mean and variance normalization)"""
12
+
13
+ @classmethod
14
+ def from_config_dict(cls, config=None):
15
+ _config = {} if config is None else config
16
+ return UtteranceCMVN(
17
+ _config.get("norm_means", True),
18
+ _config.get("norm_vars", True),
19
+ )
20
+
21
+ def __init__(self, norm_means=True, norm_vars=True):
22
+ self.norm_means, self.norm_vars = norm_means, norm_vars
23
+
24
+ def __repr__(self):
25
+ return (
26
+ self.__class__.__name__
27
+ + f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
28
+ )
29
+
30
+ def __call__(self, x):
31
+ mean = x.mean(axis=0)
32
+ square_sums = (x**2).sum(axis=0)
33
+
34
+ if self.norm_means:
35
+ x = np.subtract(x, mean)
36
+ if self.norm_vars:
37
+ var = square_sums / x.shape[0] - mean**2
38
+ std = np.sqrt(np.maximum(var, 1e-10))
39
+ x = np.divide(x, std)
40
+
41
+ return x
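
Quick check (not part of this commit): after UtteranceCMVN each feature dimension has roughly zero mean and unit variance, subject to the 1e-10 variance floor above.

    import numpy as np
    from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN

    x = np.random.randn(200, 80).astype(np.float32) * 3.0 + 1.0    # T x F features
    y = UtteranceCMVN(norm_means=True, norm_vars=True)(x)
    print(float(y.mean()), float(y.std()))                          # ~0.0, ~1.0
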
fairseq/fairseq/data/audio/frm_text_to_speech_dataset.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) 2017-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # the root directory of this source tree. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.abs
7
+
8
+ import csv
9
+ import logging
10
+ import os.path as op
11
+ from typing import List, Optional
12
+
13
+ import numpy as np
14
+ import torch
15
+ from fairseq.data import Dictionary
16
+ from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
17
+ from fairseq.data.audio.text_to_speech_dataset import (
18
+ TextToSpeechDataset,
19
+ TextToSpeechDatasetCreator,
20
+ )
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class FrmTextToSpeechDataset(TextToSpeechDataset):
26
+ def __init__(
27
+ self,
28
+ split: str,
29
+ is_train_split: bool,
30
+ data_cfg: S2TDataConfig,
31
+ audio_paths: List[str],
32
+ n_frames: List[int],
33
+ src_texts: Optional[List[str]] = None,
34
+ tgt_texts: Optional[List[str]] = None,
35
+ speakers: Optional[List[str]] = None,
36
+ src_langs: Optional[List[str]] = None,
37
+ tgt_langs: Optional[List[str]] = None,
38
+ ids: Optional[List[str]] = None,
39
+ tgt_dict: Optional[Dictionary] = None,
40
+ pre_tokenizer=None,
41
+ bpe_tokenizer=None,
42
+ n_frames_per_step=1,
43
+ speaker_to_id=None,
44
+ do_chunk=False,
45
+ chunk_bound=-1,
46
+ chunk_init=50,
47
+ chunk_incr=5,
48
+ add_eos=True,
49
+ dedup=True,
50
+ ref_fpu=-1,
51
+ ):
52
+ # It assumes texts are encoded at a fixed frame-rate
53
+ super().__init__(
54
+ split=split,
55
+ is_train_split=is_train_split,
56
+ data_cfg=data_cfg,
57
+ audio_paths=audio_paths,
58
+ n_frames=n_frames,
59
+ src_texts=src_texts,
60
+ tgt_texts=tgt_texts,
61
+ speakers=speakers,
62
+ src_langs=src_langs,
63
+ tgt_langs=tgt_langs,
64
+ ids=ids,
65
+ tgt_dict=tgt_dict,
66
+ pre_tokenizer=pre_tokenizer,
67
+ bpe_tokenizer=bpe_tokenizer,
68
+ n_frames_per_step=n_frames_per_step,
69
+ speaker_to_id=speaker_to_id,
70
+ )
71
+
72
+ self.do_chunk = do_chunk
73
+ self.chunk_bound = chunk_bound
74
+ self.chunk_init = chunk_init
75
+ self.chunk_incr = chunk_incr
76
+ self.add_eos = add_eos
77
+ self.dedup = dedup
78
+ self.ref_fpu = ref_fpu
79
+
80
+ self.chunk_size = -1
81
+
82
+ if do_chunk:
83
+ assert self.chunk_incr >= 0
84
+ assert self.pre_tokenizer is None
85
+
86
+ def __getitem__(self, index):
87
+ index, source, target, speaker_id, _, _, _ = super().__getitem__(index)
88
+ if target[-1].item() == self.tgt_dict.eos_index:
89
+ target = target[:-1]
90
+
91
+ fpu = source.size(0) / target.size(0) # frame-per-unit
92
+ fps = self.n_frames_per_step
93
+ assert (
94
+ self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
95
+ ), f"{fpu*fps} != {self.ref_fpu}"
96
+
97
+ # only chunk training split
98
+ if self.is_train_split and self.do_chunk and self.chunk_size > 0:
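+ # sample a random contiguous chunk of the target units (keeping the language-tag prefix, if any)
+ # and slice the matching span of source frames using the frame-per-unit ratio fpu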
99
+ lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
100
+ text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
101
+ size = len(text)
102
+ chunk_size = min(self.chunk_size, size)
103
+ chunk_start = np.random.randint(size - chunk_size + 1)
104
+ text = text[chunk_start : chunk_start + chunk_size]
105
+ target = torch.cat((lang, text), 0)
106
+
107
+ f_size = int(np.floor(chunk_size * fpu))
108
+ f_start = int(np.floor(chunk_start * fpu))
109
+ assert f_size > 0
110
+ source = source[f_start : f_start + f_size, :]
111
+
112
+ if self.dedup:
113
+ target = torch.unique_consecutive(target)
114
+
115
+ if self.add_eos:
116
+ eos_idx = self.tgt_dict.eos_index
117
+ target = torch.cat((target, torch.LongTensor([eos_idx])), 0)
118
+
119
+ return index, source, target, speaker_id
120
+
121
+ def set_epoch(self, epoch):
122
+ if self.is_train_split and self.do_chunk:
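+ # chunking curriculum: the chunk size grows linearly with the epoch, capped by chunk_bound when positive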
123
+ old = self.chunk_size
124
+ self.chunk_size = self.chunk_init + epoch * self.chunk_incr
125
+ if self.chunk_bound > 0:
126
+ self.chunk_size = min(self.chunk_size, self.chunk_bound)
127
+ logger.info(
128
+ (
129
+ f"{self.split}: setting chunk size "
130
+ f"from {old} to {self.chunk_size}"
131
+ )
132
+ )
133
+
134
+
135
+ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
136
+ # inherit for key names
137
+ @classmethod
138
+ def from_tsv(
139
+ cls,
140
+ root: str,
141
+ data_cfg: S2TDataConfig,
142
+ split: str,
143
+ tgt_dict,
144
+ pre_tokenizer,
145
+ bpe_tokenizer,
146
+ is_train_split: bool,
147
+ n_frames_per_step: int,
148
+ speaker_to_id,
149
+ do_chunk: bool = False,
150
+ chunk_bound: int = -1,
151
+ chunk_init: int = 50,
152
+ chunk_incr: int = 5,
153
+ add_eos: bool = True,
154
+ dedup: bool = True,
155
+ ref_fpu: float = -1,
156
+ ) -> FrmTextToSpeechDataset:
157
+ tsv_path = op.join(root, f"{split}.tsv")
158
+ if not op.isfile(tsv_path):
159
+ raise FileNotFoundError(f"Dataset not found: {tsv_path}")
160
+ with open(tsv_path) as f:
161
+ reader = csv.DictReader(
162
+ f,
163
+ delimiter="\t",
164
+ quotechar=None,
165
+ doublequote=False,
166
+ lineterminator="\n",
167
+ quoting=csv.QUOTE_NONE,
168
+ )
169
+ s = [dict(e) for e in reader]
170
+ assert len(s) > 0
171
+
172
+ ids = [ss[cls.KEY_ID] for ss in s]
173
+ audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
174
+ n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
175
+ tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
176
+ src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
177
+ speakers = [ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s]
178
+ src_langs = [ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s]
179
+ tgt_langs = [ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s]
180
+
181
+ return FrmTextToSpeechDataset(
182
+ split=split,
183
+ is_train_split=is_train_split,
184
+ data_cfg=data_cfg,
185
+ audio_paths=audio_paths,
186
+ n_frames=n_frames,
187
+ src_texts=src_texts,
188
+ tgt_texts=tgt_texts,
189
+ speakers=speakers,
190
+ src_langs=src_langs,
191
+ tgt_langs=tgt_langs,
192
+ ids=ids,
193
+ tgt_dict=tgt_dict,
194
+ pre_tokenizer=pre_tokenizer,
195
+ bpe_tokenizer=bpe_tokenizer,
196
+ n_frames_per_step=n_frames_per_step,
197
+ speaker_to_id=speaker_to_id,
198
+ do_chunk=do_chunk,
199
+ chunk_bound=chunk_bound,
200
+ chunk_init=chunk_init,
201
+ chunk_incr=chunk_incr,
202
+ add_eos=add_eos,
203
+ dedup=dedup,
204
+ ref_fpu=ref_fpu,
205
+ )
fairseq/fairseq/data/audio/hubert_dataset.py ADDED
@@ -0,0 +1,356 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import logging
8
+ import os
9
+ import sys
10
+ from typing import Any, List, Optional, Union
11
+
12
+ import numpy as np
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from fairseq.data import data_utils
17
+ from fairseq.data.fairseq_dataset import FairseqDataset
18
+ from fairseq.data.audio.audio_utils import (
19
+ parse_path,
20
+ read_from_stored_zip,
21
+ )
22
+ import io
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ def load_audio(manifest_path, max_keep, min_keep):
28
+ n_long, n_short = 0, 0
29
+ names, inds, sizes = [], [], []
30
+ with open(manifest_path) as f:
31
+ root = f.readline().strip()
32
+ for ind, line in enumerate(f):
33
+ items = line.strip().split("\t")
34
+ assert len(items) == 2, line
35
+ sz = int(items[1])
36
+ if min_keep is not None and sz < min_keep:
37
+ n_short += 1
38
+ elif max_keep is not None and sz > max_keep:
39
+ n_long += 1
40
+ else:
41
+ names.append(items[0])
42
+ inds.append(ind)
43
+ sizes.append(sz)
44
+ tot = ind + 1
45
+ logger.info(
46
+ (
47
+ f"max_keep={max_keep}, min_keep={min_keep}, "
48
+ f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
49
+ f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
50
+ )
51
+ )
52
+ return root, names, inds, tot, sizes
53
+
54
+
55
+ def load_label(label_path, inds, tot):
56
+ with open(label_path) as f:
57
+ labels = [line.rstrip() for line in f]
58
+ assert (
59
+ len(labels) == tot
60
+ ), f"number of labels does not match ({len(labels)} != {tot})"
61
+ labels = [labels[i] for i in inds]
62
+ return labels
63
+
64
+
65
+ def load_label_offset(label_path, inds, tot):
66
+ with open(label_path) as f:
67
+ code_lengths = [len(line.encode("utf-8")) for line in f]
68
+ assert (
69
+ len(code_lengths) == tot
70
+ ), f"number of labels does not match ({len(code_lengths)} != {tot})"
71
+ offsets = list(itertools.accumulate([0] + code_lengths))
72
+ offsets = [(offsets[i], offsets[i + 1]) for i in inds]
73
+ return offsets
74
+
75
+
76
+ def verify_label_lengths(
77
+ audio_sizes,
78
+ audio_rate,
79
+ label_path,
80
+ label_rate,
81
+ inds,
82
+ tot,
83
+ tol=0.1, # tolerance in seconds
84
+ ):
85
+ if label_rate < 0:
86
+ logger.info(f"{label_path} is sequence label. skipped")
87
+ return
88
+
89
+ with open(label_path) as f:
90
+ lengths = [len(line.rstrip().split()) for line in f]
91
+ assert len(lengths) == tot
92
+ lengths = [lengths[i] for i in inds]
93
+ num_invalid = 0
94
+ for i, ind in enumerate(inds):
95
+ dur_from_audio = audio_sizes[i] / audio_rate
96
+ dur_from_label = lengths[i] / label_rate
97
+ if abs(dur_from_audio - dur_from_label) > tol:
98
+ logger.warning(
99
+ (
100
+ f"audio and label duration differ too much "
101
+ f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
102
+ f"in line {ind+1} of {label_path}. Check if `label_rate` "
103
+ f"is correctly set (currently {label_rate}). "
104
+ f"num. of samples = {audio_sizes[i]}; "
105
+ f"label length = {lengths[i]}"
106
+ )
107
+ )
108
+ num_invalid += 1
109
+ if num_invalid > 0:
110
+ logger.warning(
111
+ f"total {num_invalid} (audio, label) pairs with mismatched lengths"
112
+ )
113
+
114
+
115
+ class HubertDataset(FairseqDataset):
116
+ def __init__(
117
+ self,
118
+ manifest_path: str,
119
+ sample_rate: float,
120
+ label_paths: List[str],
121
+ label_rates: Union[List[float], float], # -1 for sequence labels
122
+ pad_list: List[str],
123
+ eos_list: List[str],
124
+ label_processors: Optional[List[Any]] = None,
125
+ max_keep_sample_size: Optional[int] = None,
126
+ min_keep_sample_size: Optional[int] = None,
127
+ max_sample_size: Optional[int] = None,
128
+ shuffle: bool = True,
129
+ pad_audio: bool = False,
130
+ normalize: bool = False,
131
+ store_labels: bool = True,
132
+ random_crop: bool = False,
133
+ single_target: bool = False,
134
+ ):
135
+ self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio(
136
+ manifest_path, max_keep_sample_size, min_keep_sample_size
137
+ )
138
+ self.sample_rate = sample_rate
139
+ self.shuffle = shuffle
140
+ self.random_crop = random_crop
141
+
142
+ self.num_labels = len(label_paths)
143
+ self.pad_list = pad_list
144
+ self.eos_list = eos_list
145
+ self.label_processors = label_processors
146
+ self.single_target = single_target
147
+ self.label_rates = (
148
+ [label_rates for _ in range(len(label_paths))]
149
+ if isinstance(label_rates, float)
150
+ else label_rates
151
+ )
152
+ self.store_labels = store_labels
153
+ if store_labels:
154
+ self.label_list = [load_label(p, inds, tot) for p in label_paths]
155
+ else:
156
+ self.label_paths = label_paths
157
+ self.label_offsets_list = [
158
+ load_label_offset(p, inds, tot) for p in label_paths
159
+ ]
160
+ assert label_processors is None or len(label_processors) == self.num_labels
161
+ for label_path, label_rate in zip(label_paths, self.label_rates):
162
+ verify_label_lengths(
163
+ self.sizes, sample_rate, label_path, label_rate, inds, tot
164
+ )
165
+
166
+ self.max_sample_size = (
167
+ max_sample_size if max_sample_size is not None else sys.maxsize
168
+ )
169
+ self.pad_audio = pad_audio
170
+ self.normalize = normalize
171
+ logger.info(
172
+ f"pad_audio={pad_audio}, random_crop={random_crop}, "
173
+ f"normalize={normalize}, max_sample_size={self.max_sample_size}"
174
+ )
175
+
176
+ def get_audio(self, index):
177
+ import soundfile as sf
178
+
179
+ wav_path = os.path.join(self.audio_root, self.audio_names[index])
180
+ _path, slice_ptr = parse_path(wav_path)
181
+ if len(slice_ptr) == 0:
182
+ wav, cur_sample_rate = sf.read(_path)
183
+ else:
184
+ assert _path.endswith(".zip")
185
+ data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
186
+ f = io.BytesIO(data)
187
+ wav, cur_sample_rate = sf.read(f)
188
+ wav = torch.from_numpy(wav).float()
189
+ wav = self.postprocess(wav, cur_sample_rate)
190
+ return wav
191
+
192
+ def get_label(self, index, label_idx):
193
+ if self.store_labels:
194
+ label = self.label_list[label_idx][index]
195
+ else:
196
+ with open(self.label_paths[label_idx]) as f:
197
+ offset_s, offset_e = self.label_offsets_list[label_idx][index]
198
+ f.seek(offset_s)
199
+ label = f.read(offset_e - offset_s)
200
+
201
+ if self.label_processors is not None:
202
+ label = self.label_processors[label_idx](label)
203
+ return label
204
+
205
+ def get_labels(self, index):
206
+ return [self.get_label(index, i) for i in range(self.num_labels)]
207
+
208
+ def __getitem__(self, index):
209
+ wav = self.get_audio(index)
210
+ labels = self.get_labels(index)
211
+ return {"id": index, "source": wav, "label_list": labels}
212
+
213
+ def __len__(self):
214
+ return len(self.sizes)
215
+
216
+ def crop_to_max_size(self, wav, target_size):
217
+ size = len(wav)
218
+ diff = size - target_size
219
+ if diff <= 0:
220
+ return wav, 0
221
+
222
+ start, end = 0, target_size
223
+ if self.random_crop:
224
+ start = np.random.randint(0, diff + 1)
225
+ end = size - diff + start
226
+ return wav[start:end], start
227
+
228
+ def collater(self, samples):
229
+ # target = max(sizes) -> random_crop not used
230
+ # target = max_sample_size -> random_crop used for long
231
+ samples = [s for s in samples if s["source"] is not None]
232
+ if len(samples) == 0:
233
+ return {}
234
+
235
+ audios = [s["source"] for s in samples]
236
+ audio_sizes = [len(s) for s in audios]
237
+ if self.pad_audio:
238
+ audio_size = min(max(audio_sizes), self.max_sample_size)
239
+ else:
240
+ audio_size = min(min(audio_sizes), self.max_sample_size)
241
+ collated_audios, padding_mask, audio_starts = self.collater_audio(
242
+ audios, audio_size
243
+ )
244
+
245
+ targets_by_label = [
246
+ [s["label_list"][i] for s in samples] for i in range(self.num_labels)
247
+ ]
248
+ targets_list, lengths_list, ntokens_list = self.collater_label(
249
+ targets_by_label, audio_size, audio_starts
250
+ )
251
+
252
+ net_input = {"source": collated_audios, "padding_mask": padding_mask}
253
+ batch = {
254
+ "id": torch.LongTensor([s["id"] for s in samples]),
255
+ "net_input": net_input,
256
+ }
257
+
258
+ if self.single_target:
259
+ batch["target_lengths"] = lengths_list[0]
260
+ batch["ntokens"] = ntokens_list[0]
261
+ batch["target"] = targets_list[0]
262
+ else:
263
+ batch["target_lengths_list"] = lengths_list
264
+ batch["ntokens_list"] = ntokens_list
265
+ batch["target_list"] = targets_list
266
+ return batch
267
+
268
+ def collater_audio(self, audios, audio_size):
269
+ collated_audios = audios[0].new_zeros(len(audios), audio_size)
270
+ padding_mask = (
271
+ torch.BoolTensor(collated_audios.shape).fill_(False)
272
+ # if self.pad_audio else None
273
+ )
274
+ audio_starts = [0 for _ in audios]
275
+ for i, audio in enumerate(audios):
276
+ diff = len(audio) - audio_size
277
+ if diff == 0:
278
+ collated_audios[i] = audio
279
+ elif diff < 0:
280
+ assert self.pad_audio
281
+ collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
282
+ padding_mask[i, diff:] = True
283
+ else:
284
+ collated_audios[i], audio_starts[i] = self.crop_to_max_size(
285
+ audio, audio_size
286
+ )
287
+ return collated_audios, padding_mask, audio_starts
288
+
289
+ def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
290
+ assert label_rate > 0
291
+ s2f = label_rate / self.sample_rate
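+ # s2f converts a position counted in audio samples into a position counted in label frames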
292
+ frm_starts = [int(round(s * s2f)) for s in audio_starts]
293
+ frm_size = int(round(audio_size * s2f))
294
+ if not self.pad_audio:
295
+ rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
296
+ frm_size = min(frm_size, *rem_size)
297
+ targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
298
+ logger.debug(f"audio_starts={audio_starts}")
299
+ logger.debug(f"frame_starts={frm_starts}")
300
+ logger.debug(f"frame_size={frm_size}")
301
+
302
+ lengths = torch.LongTensor([len(t) for t in targets])
303
+ ntokens = lengths.sum().item()
304
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
305
+ return targets, lengths, ntokens
306
+
307
+ def collater_seq_label(self, targets, pad):
308
+ lengths = torch.LongTensor([len(t) for t in targets])
309
+ ntokens = lengths.sum().item()
310
+ targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
311
+ return targets, lengths, ntokens
312
+
313
+ def collater_label(self, targets_by_label, audio_size, audio_starts):
314
+ targets_list, lengths_list, ntokens_list = [], [], []
315
+ itr = zip(targets_by_label, self.label_rates, self.pad_list)
316
+ for targets, label_rate, pad in itr:
317
+ if label_rate == -1.0:
318
+ targets, lengths, ntokens = self.collater_seq_label(targets, pad)
319
+ else:
320
+ targets, lengths, ntokens = self.collater_frm_label(
321
+ targets, audio_size, audio_starts, label_rate, pad
322
+ )
323
+ targets_list.append(targets)
324
+ lengths_list.append(lengths)
325
+ ntokens_list.append(ntokens)
326
+ return targets_list, lengths_list, ntokens_list
327
+
328
+ def num_tokens(self, index):
329
+ return self.size(index)
330
+
331
+ def size(self, index):
332
+ if self.pad_audio:
333
+ return self.sizes[index]
334
+ return min(self.sizes[index], self.max_sample_size)
335
+
336
+ def ordered_indices(self):
337
+ if self.shuffle:
338
+ order = [np.random.permutation(len(self))]
339
+ else:
340
+ order = [np.arange(len(self))]
341
+
342
+ order.append(self.sizes)
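+ # np.lexsort uses the last key (sizes) as the primary key and the random permutation as tie-breaker;
+ # reversing the result orders indices by descending size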
343
+ return np.lexsort(order)[::-1]
344
+
345
+ def postprocess(self, wav, cur_sample_rate):
346
+ if wav.dim() == 2:
347
+ wav = wav.mean(-1)
348
+ assert wav.dim() == 1, wav.dim()
349
+
350
+ if cur_sample_rate != self.sample_rate:
351
+ raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
352
+
353
+ if self.normalize:
354
+ with torch.no_grad():
355
+ wav = F.layer_norm(wav, wav.shape)
356
+ return wav
fairseq/fairseq/data/audio/multi_modality_dataset.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (c) 2021-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the LICENSE file in
5
+ # the root directory of this source tree. An additional grant of patent rights
6
+ # can be found in the PATENTS file in the same directory.
7
+
8
+ import logging
9
+ import math
10
+ from typing import List, Optional, NamedTuple
11
+
12
+ import numpy as np
13
+ import torch
14
+ from fairseq.data import (
15
+ ConcatDataset,
16
+ LanguagePairDataset,
17
+ FileAudioDataset,
18
+ data_utils,
19
+ )
20
+ from fairseq.data import FairseqDataset
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class ModalityDatasetItem(NamedTuple):
26
+ datasetname: str
27
+ dataset: any
28
+ max_positions: List[int]
29
+ max_tokens: Optional[int] = None
30
+ max_sentences: Optional[int] = None
31
+
32
+
33
+ # MultiModalityDataset: it concatenates multiple datasets with different modalities.
34
+ # Compared with ConcatDataset it can 1) sample data given ratios for the different datasets and
35
+ # 2) add a mode indicating which type of dataset each sample comes from.
36
+ # It is used together with GroupedEpochBatchIterator to generate mini-batches with samples
37
+ # from the same type of dataset.
38
+ # If only one dataset is used, it behaves like the original dataset, with the mode added.
39
+ class MultiModalityDataset(ConcatDataset):
40
+ def __init__(self, datasets: List[ModalityDatasetItem]):
41
+ id_to_mode = []
42
+ dsets = []
43
+ max_tokens = []
44
+ max_sentences = []
45
+ max_positions = []
46
+ for dset in datasets:
47
+ id_to_mode.append(dset.datasetname)
48
+ dsets.append(dset.dataset)
49
+ max_tokens.append(dset.max_tokens)
50
+ max_positions.append(dset.max_positions)
51
+ max_sentences.append(dset.max_sentences)
52
+ weights = [1.0 for s in dsets]
53
+ super().__init__(dsets, weights)
54
+ self.max_tokens = max_tokens
55
+ self.max_positions = max_positions
56
+ self.max_sentences = max_sentences
57
+ self.id_to_mode = id_to_mode
58
+ self.raw_sub_batch_samplers = []
59
+ self._cur_epoch = 0
60
+
61
+ def set_epoch(self, epoch):
62
+ super().set_epoch(epoch)
63
+ self._cur_epoch = epoch
64
+
65
+ def __getitem__(self, idx):
66
+ dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
67
+ sample = self.datasets[dataset_idx][sample_idx]
68
+ return (dataset_idx, sample)
69
+
70
+ def collater(self, samples):
71
+ if len(samples) == 0:
72
+ return {}
73
+ dataset_idx = samples[0][0]
74
+ # make sure all samples in samples are from same dataset
75
+ assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
76
+ samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
77
+ # add mode
78
+ samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]
79
+
80
+ return samples
81
+
82
+ def size(self, index: int):
83
+ if len(self.datasets) == 1:
84
+ return self.datasets[0].size(index)
85
+ return super().size(index)
86
+
87
+ @property
88
+ def sizes(self):
89
+ if len(self.datasets) == 1:
90
+ return self.datasets[0].sizes
91
+ return super().sizes
92
+
93
+ def ordered_indices(self):
94
+ """
95
+ Returns indices sorted by length, so that less padding is needed.
96
+ """
97
+ if len(self.datasets) == 1:
98
+ return [self.datasets[0].ordered_indices()]
99
+ indices_group = []
100
+ for d_idx, ds in enumerate(self.datasets):
101
+ sample_num = self.cumulative_sizes[d_idx]
102
+ if d_idx > 0:
103
+ sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
104
+ assert sample_num == len(ds)
105
+ indices_group.append(ds.ordered_indices())
106
+ return indices_group
107
+
108
+ def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
109
+ if len(self.raw_sub_batch_samplers) > 0:
110
+ logger.info(" raw_sub_batch_samplers exists. No action is taken")
111
+ return
112
+ with data_utils.numpy_seed(seed):
113
+ indices = self.ordered_indices()
114
+
115
+ for i, ds in enumerate(self.datasets):
116
+ indices[i] = ds.filter_indices_by_size(
117
+ indices[i],
118
+ self.max_positions[i],
119
+ )[0]
120
+ sub_batch_sampler = ds.batch_by_size(
121
+ indices[i],
122
+ max_tokens=self.max_tokens[i],
123
+ max_sentences=self.max_sentences[i],
124
+ required_batch_size_multiple=required_batch_size_multiple,
125
+ )
126
+ self.raw_sub_batch_samplers.append(sub_batch_sampler)
127
+
128
+ def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
129
+ self.get_raw_batch_samplers(required_batch_size_multiple, seed)
130
+ batch_samplers = []
131
+ for i, _ in enumerate(self.datasets):
132
+ if i > 0:
133
+ sub_batch_sampler = [
134
+ [y + self.cumulative_sizes[i - 1] for y in x]
135
+ for x in self.raw_sub_batch_samplers[i]
136
+ ]
137
+ else:
138
+ sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
139
+ smp_r = mult_ratios[i]
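+ # up-/down-sample this dataset's batches: repeat the sampler floor(smp_r) times and draw a random
+ # subset of batches for the fractional remainder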
140
+ if smp_r != 1:
141
+ is_increase = "increased" if smp_r > 1 else "decreased"
142
+ logger.info(
143
+ "number of batch for the dataset {} is {} from {} to {}".format(
144
+ self.id_to_mode[i],
145
+ is_increase,
146
+ len(sub_batch_sampler),
147
+ int(len(sub_batch_sampler) * smp_r),
148
+ )
149
+ )
150
+ mul_samplers = []
151
+ for _ in range(math.floor(smp_r)):
152
+ mul_samplers = mul_samplers + sub_batch_sampler
153
+ if math.floor(smp_r) != smp_r:
154
+ with data_utils.numpy_seed(seed + self._cur_epoch):
155
+ np.random.shuffle(sub_batch_sampler)
156
+ smp_num = int(
157
+ (smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
158
+ )
159
+ mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
160
+ sub_batch_sampler = mul_samplers
161
+ else:
162
+ logger.info(
163
+ "dataset {} batch number is {} ".format(
164
+ self.id_to_mode[i], len(sub_batch_sampler)
165
+ )
166
+ )
167
+ batch_samplers.append(sub_batch_sampler)
168
+
169
+ return batch_samplers
170
+
171
+
172
+ class LangPairMaskDataset(FairseqDataset):
173
+ def __init__(
174
+ self,
175
+ dataset: LanguagePairDataset,
176
+ src_eos: int,
177
+ src_bos: Optional[int] = None,
178
+ noise_id: Optional[int] = -1,
179
+ mask_ratio: Optional[float] = 0,
180
+ mask_type: Optional[str] = "random",
181
+ ):
182
+ self.dataset = dataset
183
+ self.src_eos = src_eos
184
+ self.src_bos = src_bos
185
+ self.noise_id = noise_id
186
+ self.mask_ratio = mask_ratio
187
+ self.mask_type = mask_type
188
+ assert mask_type in ("random", "tail")
189
+
190
+ @property
191
+ def src_sizes(self):
192
+ return self.dataset.src_sizes
193
+
194
+ @property
195
+ def tgt_sizes(self):
196
+ return self.dataset.tgt_sizes
197
+
198
+ @property
199
+ def sizes(self):
200
+ # dataset.sizes can be a dynamically computed sizes:
201
+ return self.dataset.sizes
202
+
203
+ def get_batch_shapes(self):
204
+ if hasattr(self.dataset, "get_batch_shapes"):
205
+ return self.dataset.get_batch_shapes()
206
+ return self.dataset.buckets
207
+
208
+ def num_tokens_vec(self, indices):
209
+ return self.dataset.num_tokens_vec(indices)
210
+
211
+ def __len__(self):
212
+ return len(self.dataset)
213
+
214
+ def num_tokens(self, index):
215
+ return self.dataset.num_tokens(index)
216
+
217
+ def size(self, index):
218
+ return self.dataset.size(index)
219
+
220
+ def ordered_indices(self):
221
+ return self.dataset.ordered_indices()
222
+
223
+ @property
224
+ def supports_prefetch(self):
225
+ return getattr(self.dataset, "supports_prefetch", False)
226
+
227
+ def prefetch(self, indices):
228
+ return self.dataset.prefetch(indices)
229
+
230
+ def mask_src_tokens(self, sample):
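+ # replace a subset of source tokens with noise_id: "random" masks each token independently with
+ # probability mask_ratio, "tail" masks the trailing mask_ratio fraction; BOS/EOS are never masked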
231
+ src_item = sample["source"]
232
+ mask = None
233
+ if self.mask_type == "random":
234
+ mask = torch.rand(len(src_item)).le(self.mask_ratio)
235
+ else:
236
+ mask = torch.ones(len(src_item))
237
+ mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
238
+ mask = mask.eq(1)
239
+ if src_item[0] == self.src_bos:
240
+ mask[0] = False
241
+ if src_item[-1] == self.src_eos:
242
+ mask[-1] = False
243
+ mask_src_item = src_item.masked_fill(mask, self.noise_id)
244
+ smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
245
+ return smp
246
+
247
+ def __getitem__(self, index):
248
+ sample = self.dataset[index]
249
+ if self.mask_ratio > 0:
250
+ sample = self.mask_src_tokens(sample)
251
+ return sample
252
+
253
+ def collater(self, samples, pad_to_length=None):
254
+ return self.dataset.collater(samples, pad_to_length)
255
+
256
+
257
+ class FileAudioDatasetWrapper(FileAudioDataset):
258
+ def collater(self, samples):
259
+ samples = super().collater(samples)
260
+ if len(samples) == 0:
261
+ return {}
262
+ samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
263
+ samples["net_input"]["prev_output_tokens"] = None
264
+ del samples["net_input"]["source"]
265
+ samples["net_input"]["src_lengths"] = None
266
+ samples["net_input"]["alignment"] = None
267
+ return samples
fairseq/fairseq/data/audio/raw_audio_dataset.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import logging
8
+ import os
9
+ import sys
10
+ import time
11
+ import io
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+ from .. import FairseqDataset
18
+ from ..data_utils import compute_block_mask_1d, get_buckets, get_bucketed_sizes
19
+ from fairseq.data.audio.audio_utils import (
20
+ parse_path,
21
+ read_from_stored_zip,
22
+ is_sf_audio_data,
23
+ )
24
+ from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class RawAudioDataset(FairseqDataset):
31
+ def __init__(
32
+ self,
33
+ sample_rate,
34
+ max_sample_size=None,
35
+ min_sample_size=0,
36
+ shuffle=True,
37
+ pad=False,
38
+ normalize=False,
39
+ compute_mask=False,
40
+ feature_encoder_spec: str = "None",
41
+ mask_prob: float = 0.75,
42
+ mask_prob_adjust: float = 0,
43
+ mask_length: int = 1,
44
+ inverse_mask: bool = False,
45
+ require_same_masks: bool = True,
46
+ clone_batch: int = 1,
47
+ expand_adjacent: bool = False,
48
+ mask_dropout: float = 0,
49
+ non_overlapping: bool = False,
50
+ corpus_key=None,
51
+ ):
52
+ super().__init__()
53
+
54
+ self.sample_rate = sample_rate
55
+ self.sizes = []
56
+ self.max_sample_size = (
57
+ max_sample_size if max_sample_size is not None else sys.maxsize
58
+ )
59
+ self.min_sample_size = min_sample_size
60
+ self.pad = pad
61
+ self.shuffle = shuffle
62
+ self.normalize = normalize
63
+
64
+ self.is_compute_mask = compute_mask
65
+ self.feature_encoder_spec = eval(feature_encoder_spec)
66
+ self._features_size_map = {}
67
+ self.mask_prob = mask_prob
68
+ self.mask_prob_adjust = mask_prob_adjust
69
+ self.mask_length = mask_length
70
+ self.inverse_mask = inverse_mask
71
+ self.require_same_masks = require_same_masks
72
+ self.clone_batch = clone_batch
73
+ self.expand_adjacent = expand_adjacent
74
+ self.mask_dropout = mask_dropout
75
+ self.non_overlapping = non_overlapping
76
+ self.corpus_key = corpus_key
77
+
78
+ def __getitem__(self, index):
79
+ raise NotImplementedError()
80
+
81
+ def __len__(self):
82
+ return len(self.sizes)
83
+
84
+ def postprocess(self, feats, curr_sample_rate):
85
+ if feats.dim() == 2:
86
+ feats = feats.mean(-1)
87
+
88
+ if curr_sample_rate != self.sample_rate:
89
+ raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
90
+
91
+ assert feats.dim() == 1, feats.dim()
92
+
93
+ if self.normalize:
94
+ with torch.no_grad():
95
+ feats = F.layer_norm(feats, feats.shape)
96
+ return feats
97
+
98
+ def crop_to_max_size(self, t, target_size, dim=0):
99
+ size = t.size(dim)
100
+ diff = size - target_size
101
+ if diff <= 0:
102
+ return t
103
+
104
+ start = np.random.randint(0, diff + 1)
105
+ end = size - diff + start
106
+
107
+ slices = []
108
+ for d in range(dim):
109
+ slices.append(slice(None))
110
+ slices.append(slice(start, end))
111
+
112
+ return t[slices]
113
+
114
+ @staticmethod
115
+ def _bucket_tensor(tensor, num_pad, value):
116
+ return F.pad(tensor, (0, num_pad), value=value)
117
+
118
+ def collater(self, samples):
119
+ samples = [s for s in samples if s["source"] is not None]
120
+ if len(samples) == 0:
121
+ return {}
122
+
123
+ sources = [s["source"] for s in samples]
124
+ sizes = [len(s) for s in sources]
125
+
126
+ if self.pad:
127
+ target_size = min(max(sizes), self.max_sample_size)
128
+ else:
129
+ target_size = min(min(sizes), self.max_sample_size)
130
+
131
+ collated_sources = sources[0].new_zeros(len(sources), target_size)
132
+ padding_mask = (
133
+ torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
134
+ )
135
+ for i, (source, size) in enumerate(zip(sources, sizes)):
136
+ diff = size - target_size
137
+ if diff == 0:
138
+ collated_sources[i] = source
139
+ elif diff < 0:
140
+ assert self.pad
141
+ collated_sources[i] = torch.cat(
142
+ [source, source.new_full((-diff,), 0.0)]
143
+ )
144
+ padding_mask[i, diff:] = True
145
+ else:
146
+ collated_sources[i] = self.crop_to_max_size(source, target_size)
147
+
148
+ input = {"source": collated_sources}
149
+ if self.corpus_key is not None:
150
+ input["corpus_key"] = [self.corpus_key] * len(sources)
151
+ out = {"id": torch.LongTensor([s["id"] for s in samples])}
152
+ if self.pad:
153
+ input["padding_mask"] = padding_mask
154
+
155
+ if hasattr(self, "num_buckets") and self.num_buckets > 0:
156
+ assert self.pad, "Cannot bucket without padding first."
157
+ bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
158
+ num_pad = bucket - collated_sources.size(-1)
159
+ if num_pad:
160
+ input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
161
+ input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)
162
+
163
+ if "precomputed_mask" in samples[0]:
164
+ target_size = self._get_mask_indices_dims(target_size)
165
+ collated_mask = torch.cat(
166
+ [
167
+ self.crop_to_max_size(s["precomputed_mask"], target_size, dim=1)
168
+ for s in samples
169
+ ],
170
+ dim=0,
171
+ )
172
+ input["precomputed_mask"] = collated_mask
173
+
174
+ out["net_input"] = input
175
+ return out
176
+
177
+ def _get_mask_indices_dims(self, size, padding=0, dilation=1):
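+ # map a waveform length to the feature-encoder output length by applying the standard
+ # convolution output-size formula for every (dim, kernel_size, stride) layer, storing the result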
178
+ if size not in self.feature_encoder_spec:
179
+ L_in = size
180
+ for (_, kernel_size, stride) in self.feature_encoder_spec:
181
+ L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
182
+ L_out = 1 + L_out // stride
183
+ L_in = L_out
184
+ self._features_size_map[size] = L_out
185
+ return self._features_size_map[size]
186
+
187
+ def num_tokens(self, index):
188
+ return self.size(index)
189
+
190
+ def size(self, index):
191
+ """Return an example's size as a float or tuple. This value is used when
192
+ filtering a dataset with ``--max-positions``."""
193
+ if self.pad:
194
+ return self.sizes[index]
195
+ return min(self.sizes[index], self.max_sample_size)
196
+
197
+ def ordered_indices(self):
198
+ """Return an ordered list of indices. Batches will be constructed based
199
+ on this order."""
200
+
201
+ if self.shuffle:
202
+ order = [np.random.permutation(len(self))]
203
+ order.append(
204
+ np.minimum(
205
+ np.array(self.sizes),
206
+ self.max_sample_size,
207
+ )
208
+ )
209
+ return np.lexsort(order)[::-1]
210
+ else:
211
+ return np.arange(len(self))
212
+
213
+ def set_bucket_info(self, num_buckets):
214
+ self.num_buckets = num_buckets
215
+ if self.num_buckets > 0:
216
+ self._collated_sizes = np.minimum(
217
+ np.array(self.sizes),
218
+ self.max_sample_size,
219
+ )
220
+ self.buckets = get_buckets(
221
+ self._collated_sizes,
222
+ self.num_buckets,
223
+ )
224
+ self._bucketed_sizes = get_bucketed_sizes(
225
+ self._collated_sizes, self.buckets
226
+ )
227
+ logger.info(
228
+ f"{len(self.buckets)} bucket(s) for the audio dataset: "
229
+ f"{self.buckets}"
230
+ )
231
+
232
+ def filter_indices_by_size(self, indices, max_sizes):
233
+ return indices, []
234
+
235
+
236
+ class FileAudioDataset(RawAudioDataset):
237
+ def __init__(
238
+ self,
239
+ manifest_path,
240
+ sample_rate,
241
+ max_sample_size=None,
242
+ min_sample_size=0,
243
+ shuffle=True,
244
+ pad=False,
245
+ normalize=False,
246
+ num_buckets=0,
247
+ compute_mask=False,
248
+ text_compression_level=TextCompressionLevel.none,
249
+ **mask_compute_kwargs,
250
+ ):
251
+ super().__init__(
252
+ sample_rate=sample_rate,
253
+ max_sample_size=max_sample_size,
254
+ min_sample_size=min_sample_size,
255
+ shuffle=shuffle,
256
+ pad=pad,
257
+ normalize=normalize,
258
+ compute_mask=compute_mask,
259
+ **mask_compute_kwargs,
260
+ )
261
+
262
+ self.text_compressor = TextCompressor(level=text_compression_level)
263
+
264
+ skipped = 0
265
+ self.fnames = []
266
+ sizes = []
267
+ self.skipped_indices = set()
268
+
269
+ with open(manifest_path, "r") as f:
270
+ self.root_dir = f.readline().strip()
271
+ for i, line in enumerate(f):
272
+ items = line.strip().split("\t")
273
+ assert len(items) == 2, line
274
+ sz = int(items[1])
275
+ if min_sample_size is not None and sz < min_sample_size:
276
+ skipped += 1
277
+ self.skipped_indices.add(i)
278
+ continue
279
+ self.fnames.append(self.text_compressor.compress(items[0]))
280
+ sizes.append(sz)
281
+ logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
282
+
283
+ self.sizes = np.array(sizes, dtype=np.int64)
284
+
285
+ try:
286
+ import pyarrow
287
+
288
+ self.fnames = pyarrow.array(self.fnames)
289
+ except:
290
+ logger.debug(
291
+ "Could not create a pyarrow array. Please install pyarrow for better performance"
292
+ )
293
+ pass
294
+
295
+ self.set_bucket_info(num_buckets)
296
+
297
+ def __getitem__(self, index):
298
+ import soundfile as sf
299
+
300
+ fn = self.fnames[index]
301
+ fn = fn if isinstance(self.fnames, list) else fn.as_py()
302
+ fn = self.text_compressor.decompress(fn)
303
+ path_or_fp = os.path.join(self.root_dir, fn)
304
+ _path, slice_ptr = parse_path(path_or_fp)
305
+ if len(slice_ptr) == 2:
306
+ byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
307
+ assert is_sf_audio_data(byte_data)
308
+ path_or_fp = io.BytesIO(byte_data)
309
+
310
+ retry = 3
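+ # retry transient read failures a few times, sleeping longer after each failed attempt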
311
+ wav = None
312
+ for i in range(retry):
313
+ try:
314
+ wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
315
+ break
316
+ except Exception as e:
317
+ logger.warning(
318
+ f"Failed to read {path_or_fp}: {e}. Sleeping for {1 * i}"
319
+ )
320
+ time.sleep(1 * i)
321
+
322
+ if wav is None:
323
+ raise Exception(f"Failed to load {path_or_fp}")
324
+
325
+ feats = torch.from_numpy(wav).float()
326
+ feats = self.postprocess(feats, curr_sample_rate)
327
+
328
+ v = {"id": index, "source": feats}
329
+
330
+ if self.is_compute_mask:
331
+ T = self._get_mask_indices_dims(feats.size(-1))
332
+ mask = compute_block_mask_1d(
333
+ shape=(self.clone_batch, T),
334
+ mask_prob=self.mask_prob,
335
+ mask_length=self.mask_length,
336
+ mask_prob_adjust=self.mask_prob_adjust,
337
+ inverse_mask=self.inverse_mask,
338
+ require_same_masks=True,
339
+ expand_adjcent=self.expand_adjacent,
340
+ mask_dropout=self.mask_dropout,
341
+ non_overlapping=self.non_overlapping,
342
+ )
343
+
344
+ v["precomputed_mask"] = mask
345
+
346
+ return v
347
+
348
+
349
+ class BinarizedAudioDataset(RawAudioDataset):
350
+ def __init__(
351
+ self,
352
+ data_dir,
353
+ split,
354
+ sample_rate,
355
+ max_sample_size=None,
356
+ min_sample_size=0,
357
+ shuffle=True,
358
+ pad=False,
359
+ normalize=False,
360
+ num_buckets=0,
361
+ compute_mask=False,
362
+ **mask_compute_kwargs,
363
+ ):
364
+ super().__init__(
365
+ sample_rate=sample_rate,
366
+ max_sample_size=max_sample_size,
367
+ min_sample_size=min_sample_size,
368
+ shuffle=shuffle,
369
+ pad=pad,
370
+ normalize=normalize,
371
+ compute_mask=compute_mask,
372
+ **mask_compute_kwargs,
373
+ )
374
+
375
+ from fairseq.data import data_utils, Dictionary
376
+
377
+ self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))
378
+
379
+ root_path = os.path.join(data_dir, f"{split}.root")
380
+ if os.path.exists(root_path):
381
+ with open(root_path, "r") as f:
382
+ self.root_dir = next(f).strip()
383
+ else:
384
+ self.root_dir = None
385
+
386
+ fnames_path = os.path.join(data_dir, split)
387
+ self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
388
+ lengths_path = os.path.join(data_dir, f"{split}.lengths")
389
+
390
+ with open(lengths_path, "r") as f:
391
+ for line in f:
392
+ sz = int(line.rstrip())
393
+ assert (
394
+ sz >= min_sample_size
395
+ ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
396
+ self.sizes.append(sz)
397
+
398
+ self.sizes = np.array(self.sizes, dtype=np.int64)
399
+
400
+ self.set_bucket_info(num_buckets)
401
+ logger.info(f"loaded {len(self.fnames)} samples")
402
+
403
+ def __getitem__(self, index):
404
+ import soundfile as sf
405
+
406
+ fname = self.fnames_dict.string(self.fnames[index], separator="")
407
+ if self.root_dir:
408
+ fname = os.path.join(self.root_dir, fname)
409
+
410
+ wav, curr_sample_rate = sf.read(fname)
411
+ feats = torch.from_numpy(wav).float()
412
+ feats = self.postprocess(feats, curr_sample_rate)
413
+ v = {"id": index, "source": feats}
414
+
415
+ if self.is_compute_mask:
416
+ T = self._get_mask_indices_dims(feats.size(-1))
417
+ mask = compute_block_mask_1d(
418
+ shape=(self.clone_batch, T),
419
+ mask_prob=self.mask_prob,
420
+ mask_length=self.mask_length,
421
+ mask_prob_adjust=self.mask_prob_adjust,
422
+ inverse_mask=self.inverse_mask,
423
+ require_same_masks=True,
424
+ expand_adjcent=self.expand_adjacent,
425
+ mask_dropout=self.mask_dropout,
426
+ non_overlapping=self.non_overlapping,
427
+ )
428
+
429
+ v["precomputed_mask"] = mask
430
+
431
+ return v
fairseq/fairseq/data/audio/speech_to_speech_dataset.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Dict, List, Optional, Tuple
10
+
11
+ import torch
12
+
13
+ from fairseq.data import ConcatDataset, Dictionary
14
+ from fairseq.data import data_utils as fairseq_data_utils
15
+ from fairseq.data.audio.audio_utils import get_features_or_waveform
16
+ from fairseq.data.audio.data_cfg import S2SDataConfig
17
+ from fairseq.data.audio.speech_to_text_dataset import (
18
+ SpeechToTextDataset,
19
+ SpeechToTextDatasetCreator,
20
+ TextTargetMultitaskData,
21
+ _collate_frames,
22
+ )
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ @dataclass
28
+ class SpeechToSpeechDatasetItem(object):
29
+ index: int
30
+ source: torch.Tensor
31
+ target: Optional[torch.Tensor] = None
32
+ target_speaker: Optional[torch.Tensor] = None
33
+ tgt_lang_tag: Optional[int] = None
34
+
35
+
36
+ class SpeechToSpeechDataset(SpeechToTextDataset):
37
+ def __init__(
38
+ self,
39
+ split: str,
40
+ is_train_split: bool,
41
+ data_cfg: S2SDataConfig,
42
+ src_audio_paths: List[str],
43
+ src_n_frames: List[int],
44
+ tgt_audio_paths: List[str],
45
+ tgt_n_frames: List[int],
46
+ src_langs: Optional[List[str]] = None,
47
+ tgt_langs: Optional[List[str]] = None,
48
+ ids: Optional[List[str]] = None,
49
+ target_is_code: bool = False,
50
+ tgt_dict: Dictionary = None,
51
+ n_frames_per_step: int = 1,
52
+ ):
53
+ tgt_texts = tgt_audio_paths if target_is_code else None
54
+ super().__init__(
55
+ split=split,
56
+ is_train_split=is_train_split,
57
+ cfg=data_cfg,
58
+ audio_paths=src_audio_paths,
59
+ n_frames=src_n_frames,
60
+ ids=ids,
61
+ tgt_dict=tgt_dict,
62
+ tgt_texts=tgt_texts,
63
+ src_langs=src_langs,
64
+ tgt_langs=tgt_langs,
65
+ n_frames_per_step=n_frames_per_step,
66
+ )
67
+
68
+ self.tgt_audio_paths = tgt_audio_paths
69
+ self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]
70
+
71
+ assert not target_is_code or tgt_dict is not None
72
+ self.target_is_code = target_is_code
73
+
74
+ assert len(tgt_audio_paths) == self.n_samples
75
+ assert len(tgt_n_frames) == self.n_samples
76
+
77
+ self.tgt_speakers = None
78
+ if self.cfg.target_speaker_embed:
79
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
80
+ self.cfg.target_speaker_embed, split
81
+ )
82
+ spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
83
+ self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
84
+ assert len(self.tgt_speakers) == self.n_samples
85
+
86
+ logger.info(self.__repr__())
87
+
88
+ def pack_units(self, input: torch.Tensor) -> torch.Tensor:
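+ # pack each group of n_frames_per_step consecutive units into a single id by treating the group
+ # as digits in base (len(tgt_dict) - 4), i.e. the vocabulary without the special symbols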
89
+ if self.n_frames_per_step <= 1:
90
+ return input
91
+
92
+ offset = 4
93
+ vocab_size = (
94
+ len(self.tgt_dict) - offset
95
+ ) # remove offset from <bos>, <pad>, <eos>, <unk>, which is specific to fairseq dictionary
96
+
97
+ assert input.dim() == 1
98
+ stacked_input = (
99
+ input[:-1].view(-1, self.n_frames_per_step) - offset
100
+ ) # remove <eos>
101
+ scale = [
102
+ pow(vocab_size, self.n_frames_per_step - 1 - i)
103
+ for i in range(self.n_frames_per_step)
104
+ ]
105
+ scale = torch.LongTensor(scale).squeeze(0)
106
+ res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1])
107
+ res[:-1] = (stacked_input * scale).sum(dim=1) + offset
108
+
109
+ return res
110
+
111
+ def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
112
+ source = self._get_source_audio(index)
113
+
114
+ tgt_lang_tag = None
115
+ if self.cfg.prepend_tgt_lang_tag_as_bos:
116
+ # prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target
117
+ tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
118
+
119
+ if not self.target_is_code:
120
+ target = get_features_or_waveform(self.tgt_audio_paths[index])
121
+ target = torch.from_numpy(target).float()
122
+ target = self.pack_frames(target)
123
+ else:
124
+ target = self.tgt_dict.encode_line(
125
+ self.tgt_audio_paths[index],
126
+ add_if_not_exist=False,
127
+ append_eos=True,
128
+ ).long()
129
+ if self.n_frames_per_step > 1:
130
+ n_tgt_frame = target.size(0) - 1 # exclude <eos>
131
+ keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
132
+ target = torch.cat(
133
+ (
134
+ target[:keep_n_tgt_frame],
135
+ target.new_full((1,), self.tgt_dict.eos()),
136
+ ),
137
+ dim=0,
138
+ )
139
+
140
+ if self.tgt_speakers:
141
+ tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
142
+ tgt_spk = torch.from_numpy(tgt_spk).float()
143
+ else:
144
+ tgt_spk = torch.FloatTensor([])
145
+
146
+ return SpeechToSpeechDatasetItem(
147
+ index=index,
148
+ source=source,
149
+ target=target,
150
+ target_speaker=tgt_spk,
151
+ tgt_lang_tag=tgt_lang_tag,
152
+ )
153
+
154
+ def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor:
155
+ if self.target_is_code:
156
+ target = fairseq_data_utils.collate_tokens(
157
+ [x.target for x in samples],
158
+ self.tgt_dict.pad(),
159
+ self.tgt_dict.eos(),
160
+ left_pad=False,
161
+ move_eos_to_beginning=False,
162
+ )
163
+ # convert stacked units to a single id
164
+ pack_targets = [self.pack_units(x.target) for x in samples]
165
+ prev_output_tokens = fairseq_data_utils.collate_tokens(
166
+ pack_targets,
167
+ self.tgt_dict.pad(),
168
+ self.tgt_dict.eos(),
169
+ left_pad=False,
170
+ move_eos_to_beginning=True,
171
+ )
172
+ target_lengths = torch.tensor(
173
+ [x.size(0) for x in pack_targets], dtype=torch.long
174
+ )
175
+ else:
176
+ target = _collate_frames([x.target for x in samples], is_audio_input=False)
177
+ bsz, _, d = target.size()
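+ # teacher-forcing input for spectrogram targets: prepend an all-zero frame and drop the last frame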
178
+ prev_output_tokens = torch.cat(
179
+ (target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
180
+ )
181
+ target_lengths = torch.tensor(
182
+ [x.target.size(0) for x in samples], dtype=torch.long
183
+ )
184
+
185
+ return target, prev_output_tokens, target_lengths
186
+
187
+ def collater(
188
+ self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
189
+ ) -> Dict:
190
+ if len(samples) == 0:
191
+ return {}
192
+ indices = torch.tensor([x.index for x in samples], dtype=torch.long)
193
+ frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
194
+ # sort samples by descending number of frames
195
+ n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
196
+ n_frames, order = n_frames.sort(descending=True)
197
+ indices = indices.index_select(0, order)
198
+ frames = frames.index_select(0, order)
199
+
200
+ target, prev_output_tokens, target_lengths = self._collate_target(samples)
201
+ target = target.index_select(0, order)
202
+ target_lengths = target_lengths.index_select(0, order)
203
+ prev_output_tokens = prev_output_tokens.index_select(0, order)
204
+ ntokens = sum(x.target.size(0) for x in samples)
205
+
206
+ tgt_speakers = None
207
+ if self.cfg.target_speaker_embed:
208
+ tgt_speakers = _collate_frames(
209
+ [x.target_speaker for x in samples], is_audio_input=True
210
+ ).index_select(0, order)
211
+
212
+ net_input = {
213
+ "src_tokens": frames,
214
+ "src_lengths": n_frames,
215
+ "prev_output_tokens": prev_output_tokens,
216
+ "tgt_speaker": tgt_speakers, # TODO: unify "speaker" and "tgt_speaker"
217
+ }
218
+ if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
219
+ for i in range(len(samples)):
220
+ net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
221
+ out = {
222
+ "id": indices,
223
+ "net_input": net_input,
224
+ "speaker": tgt_speakers, # to support Tacotron2 loss for speech-to-spectrogram model
225
+ "target": target,
226
+ "target_lengths": target_lengths,
227
+ "ntokens": ntokens,
228
+ "nsentences": len(samples),
229
+ }
230
+ if return_order:
231
+ out["order"] = order
232
+ return out
233
+
234
+
235
+ class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
236
+ def __init__(self, **kwargs):
237
+ super().__init__(**kwargs)
238
+ self.multitask_data = {}
239
+
240
+ def add_multitask_dataset(self, task_name, task_data):
241
+ self.multitask_data[task_name] = task_data
242
+
243
+ def __getitem__(
244
+ self, index: int
245
+ ) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
246
+ s2s_data = super().__getitem__(index)
247
+
248
+ multitask_target = {}
249
+ sample_id = self.ids[index]
250
+ tgt_lang = self.tgt_langs[index]
251
+ for task_name, task_dataset in self.multitask_data.items():
252
+ multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
253
+
254
+ return s2s_data, multitask_target
255
+
256
+ def collater(
257
+ self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
258
+ ) -> Dict:
259
+ if len(samples) == 0:
260
+ return {}
261
+
262
+ out = super().collater([s for s, _ in samples], return_order=True)
263
+ order = out["order"]
264
+ del out["order"]
265
+
266
+ for task_name, task_dataset in self.multitask_data.items():
267
+ if "multitask" not in out:
268
+ out["multitask"] = {}
269
+ d = [s[task_name] for _, s in samples]
270
+ task_target = task_dataset.collater(d)
271
+ out["multitask"][task_name] = {
272
+ "target": task_target["target"].index_select(0, order),
273
+ "target_lengths": task_target["target_lengths"].index_select(0, order),
274
+ "ntokens": task_target["ntokens"],
275
+ }
276
+ out["multitask"][task_name]["net_input"] = {
277
+ "prev_output_tokens": task_target["prev_output_tokens"].index_select(
278
+ 0, order
279
+ ),
280
+ }
281
+
282
+ return out
283
+
284
+
285
+ class SpeechToSpeechDatasetCreator(object):
286
+ # mandatory columns
287
+ KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
288
+ KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
289
+ # optional columns
290
+ KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
291
+ # default values
292
+ DEFAULT_LANG = ""
293
+
294
+ @classmethod
295
+ def _from_list(
296
+ cls,
297
+ split_name: str,
298
+ is_train_split,
299
+ samples: List[Dict],
300
+ data_cfg: S2SDataConfig,
301
+ target_is_code: bool = False,
302
+ tgt_dict: Dictionary = None,
303
+ n_frames_per_step: int = 1,
304
+ multitask: Optional[Dict] = None,
305
+ ) -> SpeechToSpeechDataset:
306
+ audio_root = Path(data_cfg.audio_root)
307
+ ids = [s[cls.KEY_ID] for s in samples]
308
+ src_audio_paths = [
309
+ (audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples
310
+ ]
311
+ tgt_audio_paths = [
312
+ s[cls.KEY_TGT_AUDIO]
313
+ if target_is_code
314
+ else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix()
315
+ for s in samples
316
+ ]
317
+ src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples]
318
+ tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples]
319
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
320
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
321
+
322
+ has_multitask = multitask is not None and len(multitask.keys()) > 0
323
+ dataset_cls = (
324
+ SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
325
+ )
326
+
327
+ ds = dataset_cls(
328
+ split=split_name,
329
+ is_train_split=is_train_split,
330
+ data_cfg=data_cfg,
331
+ src_audio_paths=src_audio_paths,
332
+ src_n_frames=src_n_frames,
333
+ tgt_audio_paths=tgt_audio_paths,
334
+ tgt_n_frames=tgt_n_frames,
335
+ src_langs=src_langs,
336
+ tgt_langs=tgt_langs,
337
+ ids=ids,
338
+ target_is_code=target_is_code,
339
+ tgt_dict=tgt_dict,
340
+ n_frames_per_step=n_frames_per_step,
341
+ )
342
+
343
+ if has_multitask:
344
+ for task_name, task_obj in multitask.items():
345
+ task_data = TextTargetMultitaskData(
346
+ task_obj.args, split_name, task_obj.target_dictionary
347
+ )
348
+ ds.add_multitask_dataset(task_name, task_data)
349
+ return ds
350
+
351
+ @classmethod
352
+ def from_tsv(
353
+ cls,
354
+ root: str,
355
+ data_cfg: S2SDataConfig,
356
+ splits: str,
357
+ is_train_split: bool,
358
+ epoch: int,
359
+ seed: int,
360
+ target_is_code: bool = False,
361
+ tgt_dict: Dictionary = None,
362
+ n_frames_per_step: int = 1,
363
+ multitask: Optional[Dict] = None,
364
+ ) -> SpeechToSpeechDataset:
365
+ datasets = []
366
+ for split in splits.split(","):
367
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
368
+ ds = cls._from_list(
369
+ split_name=split,
370
+ is_train_split=is_train_split,
371
+ samples=samples,
372
+ data_cfg=data_cfg,
373
+ target_is_code=target_is_code,
374
+ tgt_dict=tgt_dict,
375
+ n_frames_per_step=n_frames_per_step,
376
+ multitask=multitask,
377
+ )
378
+ datasets.append(ds)
379
+ return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
fairseq/fairseq/data/audio/speech_to_text_joint_dataset.py ADDED
@@ -0,0 +1,359 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional

import torch

from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
)

logger = logging.getLogger(__name__)


class S2TJointDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def src_vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("src_vocab_filename", "src_dict.txt")

    @property
    def src_pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returning
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        return self.config.get("src_pre_tokenizer", {"tokenizer": None})

    @property
    def src_bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply on source text after pre-tokenization.
        Returning a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        return self.config.get("src_bpe_tokenizer", {"bpe": None})

    @property
    def prepend_tgt_lang_tag_no_change(self) -> bool:
        """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
        to-many multilingual setting). No change needed during inference.
        This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
        """
        value = self.config.get("prepend_tgt_lang_tag_no_change", None)
        if value is None:
            return self.config.get("prepend_tgt_lang_tag_as_bos", False)
        return value

    @property
    def sampling_text_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling. (text
        input only) (alpha = 1 for no resampling)"""
        return self.config.get("sampling_text_alpha", 1.0)


class SpeechToTextJointDatasetItem(NamedTuple):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    src_txt_tokens: Optional[torch.Tensor] = None
    tgt_lang_tag: Optional[int] = None
    src_lang_tag: Optional[int] = None
    tgt_alignment: Optional[torch.Tensor] = None


# use_src_lang_id:
#   0: don't use src_lang_id
#   1: attach src_lang_id to the src_txt_tokens as eos
class SpeechToTextJointDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TJointDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        src_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        src_pre_tokenizer=None,
        src_bpe_tokenizer=None,
        append_eos: Optional[bool] = True,
        alignment: Optional[List[str]] = None,
        use_src_lang_id: Optional[int] = 0,
    ):
        super().__init__(
            split,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            append_eos=append_eos,
        )

        self.src_dict = src_dict
        self.src_pre_tokenizer = src_pre_tokenizer
        self.src_bpe_tokenizer = src_bpe_tokenizer
        self.alignment = None
        self.use_src_lang_id = use_src_lang_id
        if alignment is not None:
            self.alignment = [
                [float(s) for s in sample.split()] for sample in alignment
            ]

    def get_tokenized_src_text(self, index: int):
        text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
        text = self.tokenize(self.src_bpe_tokenizer, text)
        return text

    def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
        s2t_dataset_item = super().__getitem__(index)
        src_tokens = None
        src_lang_tag = None
        if self.src_texts is not None and self.src_dict is not None:
            src_tokens = self.get_tokenized_src_text(index)
            src_tokens = self.src_dict.encode_line(
                src_tokens, add_if_not_exist=False, append_eos=True
            ).long()
            if self.use_src_lang_id > 0:
                src_lang_tag = self.get_lang_tag_idx(
                    self.src_langs[index], self.src_dict
                )
        tgt_lang_tag = None
        if self.cfg.prepend_tgt_lang_tag_no_change:
            # prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
            tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
        ali = None
        if self.alignment is not None:
            ali = torch.Tensor(self.alignment[index]).float()

        return SpeechToTextJointDatasetItem(
            index=index,
            source=s2t_dataset_item.source,
            target=s2t_dataset_item.target,
            src_txt_tokens=src_tokens,
            tgt_lang_tag=tgt_lang_tag,
            src_lang_tag=src_lang_tag,
            tgt_alignment=ali,
        )

    def __len__(self):
        return self.n_samples

    def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
        s2t_out = super().collater(samples, return_order=True)
        if s2t_out == {}:
            return s2t_out
        net_input, order = s2t_out["net_input"], s2t_out["order"]

        if self.src_texts is not None and self.src_dict is not None:
            src_txt_tokens = fairseq_data_utils.collate_tokens(
                [x.src_txt_tokens for x in samples],
                self.src_dict.pad(),
                self.src_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            src_txt_lengths = torch.tensor(
                [x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
            )
            if self.use_src_lang_id > 0:
                src_lang_idxs = torch.tensor(
                    [s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
                )
                if self.use_src_lang_id == 1:  # replace eos with lang_id
                    eos_idx = src_txt_lengths - 1
                    src_txt_tokens.scatter_(
                        1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
                    )
                else:
                    raise NotImplementedError("Implementation is required")

            src_txt_tokens = src_txt_tokens.index_select(0, order)
            src_txt_lengths = src_txt_lengths.index_select(0, order)
            net_input["src_txt_tokens"] = src_txt_tokens
            net_input["src_txt_lengths"] = src_txt_lengths

        net_input["alignment"] = None
        if self.alignment is not None:
            max_len = max([s.tgt_alignment.size(0) for s in samples])
            alignment = torch.ones(len(samples), max_len).float()
            for i, s in enumerate(samples):
                cur_len = s.tgt_alignment.size(0)
                alignment[i][:cur_len].copy_(s.tgt_alignment)
            net_input["alignment"] = alignment.index_select(0, order)

        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
            for i in range(len(samples)):
                net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag

        out = {
            "id": s2t_out["id"],
            "net_input": net_input,
            "target": s2t_out["target"],
            "target_lengths": s2t_out["target_lengths"],
            "ntokens": s2t_out["ntokens"],
            "nsentences": len(samples),
        }
        return out


class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
    KEY_ALIGN = "align"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TJointDataConfig,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        append_eos,
        use_src_lang_id,
    ) -> SpeechToTextJointDataset:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_alignment = None
        if cls.KEY_ALIGN in samples[0].keys():
            tgt_alignment = [s[cls.KEY_ALIGN] for s in samples]
        return SpeechToTextJointDataset(
            split_name,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            src_dict=src_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            src_pre_tokenizer=src_pre_tokenizer,
            src_bpe_tokenizer=src_bpe_tokenizer,
            append_eos=append_eos,
            alignment=tgt_alignment,
            use_src_lang_id=use_src_lang_id,
        )

    @classmethod
    def _from_tsv(
        cls,
        root: str,
        cfg: S2TJointDataConfig,
        split: str,
        tgt_dict,
        src_dict,
        is_train_split: bool,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        append_eos: bool,
        use_src_lang_id: int,
    ) -> SpeechToTextJointDataset:
        samples = cls._load_samples_from_tsv(root, split)
        return cls._from_list(
            split,
            is_train_split,
            samples,
            cfg,
            tgt_dict,
            src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            src_pre_tokenizer,
            src_bpe_tokenizer,
            append_eos,
            use_src_lang_id,
        )

    @classmethod
    def from_tsv(
        cls,
        root: str,
        cfg: S2TJointDataConfig,
        splits: str,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        append_eos: Optional[bool] = True,
        use_src_lang_id: Optional[int] = 0,
    ) -> SpeechToTextJointDataset:
        datasets = [
            cls._from_tsv(
                root,
                cfg,
                split,
                tgt_dict,
                src_dict,
                is_train_split,
                pre_tokenizer,
                bpe_tokenizer,
                src_pre_tokenizer,
                src_bpe_tokenizer,
                append_eos=append_eos,
                use_src_lang_id=use_src_lang_id,
            )
            for split in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(
                    d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
                )
                for r, d in zip(size_ratios, datasets)
            ]

        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
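
A minimal usage sketch for the joint S2T dataset added above (not part of the commit): the data root, config.yaml and dictionary filenames below are placeholders, and the YAML keys are the ones exposed by S2TJointDataConfig.

# Hypothetical driver for SpeechToTextJointDatasetCreator; all paths are
# placeholders and no tokenizers are configured.
from pathlib import Path

from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_joint_dataset import (
    S2TJointDataConfig,
    SpeechToTextJointDatasetCreator,
)

data_root = "/path/to/s2t_data"  # placeholder
cfg = S2TJointDataConfig(Path(data_root) / "config.yaml")
tgt_dict = Dictionary.load(f"{data_root}/{cfg.vocab_filename}")
src_dict = Dictionary.load(f"{data_root}/{cfg.src_vocab_filename}")

# One dataset per comma-separated split; temperature-based resampling only
# kicks in for multi-split training when cfg.sampling_alpha != 1.0.
dataset = SpeechToTextJointDatasetCreator.from_tsv(
    data_root, cfg, "dev", tgt_dict, src_dict,
    pre_tokenizer=None, bpe_tokenizer=None,
    src_pre_tokenizer=None, src_bpe_tokenizer=None,
    is_train_split=False, epoch=1, seed=1,
)
batch = dataset.collater([dataset[0], dataset[1]])  # adds src_txt_tokens/src_txt_lengths
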
fairseq/fairseq/data/audio/text_to_speech_dataset.py ADDED
@@ -0,0 +1,250 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.abs

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch

from fairseq.data import Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    _collate_frames,
)


@dataclass
class TextToSpeechDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    speaker_id: Optional[int] = None
    duration: Optional[torch.Tensor] = None
    pitch: Optional[torch.Tensor] = None
    energy: Optional[torch.Tensor] = None


class TextToSpeechDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        durations: Optional[List[List[int]]] = None,
        pitches: Optional[List[str]] = None,
        energies: Optional[List[str]] = None,
    ):
        super(TextToSpeechDataset, self).__init__(
            split,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )
        self.durations = durations
        self.pitches = pitches
        self.energies = energies

    def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
        s2t_item = super().__getitem__(index)

        duration, pitch, energy = None, None, None
        if self.durations is not None:
            duration = torch.tensor(
                self.durations[index] + [0], dtype=torch.long  # pad 0 for EOS
            )
        if self.pitches is not None:
            pitch = get_features_or_waveform(self.pitches[index])
            pitch = torch.from_numpy(
                np.concatenate((pitch, [0]))  # pad 0 for EOS
            ).float()
        if self.energies is not None:
            energy = get_features_or_waveform(self.energies[index])
            energy = torch.from_numpy(
                np.concatenate((energy, [0]))  # pad 0 for EOS
            ).float()
        return TextToSpeechDatasetItem(
            index=index,
            source=s2t_item.source,
            target=s2t_item.target,
            speaker_id=s2t_item.speaker_id,
            duration=duration,
            pitch=pitch,
            energy=energy,
        )

    def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
        if len(samples) == 0:
            return {}

        src_lengths, order = torch.tensor(
            [s.target.shape[0] for s in samples], dtype=torch.long
        ).sort(descending=True)
        id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
            0, order
        )
        feat = _collate_frames(
            [s.source for s in samples], self.cfg.use_audio_input
        ).index_select(0, order)
        target_lengths = torch.tensor(
            [s.source.shape[0] for s in samples], dtype=torch.long
        ).index_select(0, order)

        src_tokens = fairseq_data_utils.collate_tokens(
            [s.target for s in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        ).index_select(0, order)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )

        bsz, _, d = feat.size()
        prev_output_tokens = torch.cat(
            (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
        )

        durations, pitches, energies = None, None, None
        if self.durations is not None:
            durations = fairseq_data_utils.collate_tokens(
                [s.duration for s in samples], 0
            ).index_select(0, order)
            assert src_tokens.shape[1] == durations.shape[1]
        if self.pitches is not None:
            pitches = _collate_frames([s.pitch for s in samples], True)
            pitches = pitches.index_select(0, order)
            assert src_tokens.shape[1] == pitches.shape[1]
        if self.energies is not None:
            energies = _collate_frames([s.energy for s in samples], True)
            energies = energies.index_select(0, order)
            assert src_tokens.shape[1] == energies.shape[1]
        src_texts = [self.tgt_dict.string(samples[i].target) for i in order]

        return {
            "id": id_,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": prev_output_tokens,
            },
            "speaker": speaker,
            "target": feat,
            "durations": durations,
            "pitches": pitches,
            "energies": energies,
            "target_lengths": target_lengths,
            "ntokens": sum(target_lengths).item(),
            "nsentences": len(samples),
            "src_texts": src_texts,
        }


class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
    KEY_DURATION = "duration"
    KEY_PITCH = "pitch"
    KEY_ENERGY = "energy"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask=None,
    ) -> TextToSpeechDataset:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        durations = [s.get(cls.KEY_DURATION, None) for s in samples]
        durations = [
            None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
        ]
        durations = None if any(dd is None for dd in durations) else durations

        pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
        pitches = [
            None if pp is None else (audio_root / pp).as_posix() for pp in pitches
        ]
        pitches = None if any(pp is None for pp in pitches) else pitches

        energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
        energies = [
            None if ee is None else (audio_root / ee).as_posix() for ee in energies
        ]
        energies = None if any(ee is None for ee in energies) else energies

        return TextToSpeechDataset(
            split_name,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts,
            tgt_texts,
            speakers,
            src_langs,
            tgt_langs,
            ids,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
            durations,
            pitches,
            energies,
        )
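
The teacher-forcing shift that builds prev_output_tokens in the collater above is easiest to see in isolation; a small self-contained sketch with made-up shapes:

# Stand-alone illustration (made-up shapes) of the shift used in
# TextToSpeechDataset.collater: the target feature matrix is shifted right by
# one frame and a zero frame is prepended.
import torch

bsz, n_frames, d = 2, 5, 80  # e.g. a toy batch of mel-spectrogram frames
feat = torch.randn(bsz, n_frames, d)  # collated target features
prev_output_tokens = torch.cat(
    (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
)
assert prev_output_tokens.shape == feat.shape
assert torch.equal(prev_output_tokens[:, 0], torch.zeros(bsz, d))
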
fairseq/fairseq/data/audio/waveform_transforms/__init__.py ADDED
@@ -0,0 +1,48 @@
import os
from fairseq.data.audio import (
    AudioTransform,
    CompositeAudioTransform,
    import_transforms,
    register_audio_transform,
)


class AudioWaveformTransform(AudioTransform):
    pass


AUDIO_WAVEFORM_TRANSFORM_REGISTRY = {}
AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES = set()


def get_audio_waveform_transform(name):
    return AUDIO_WAVEFORM_TRANSFORM_REGISTRY[name]


def register_audio_waveform_transform(name):
    return register_audio_transform(
        name,
        AudioWaveformTransform,
        AUDIO_WAVEFORM_TRANSFORM_REGISTRY,
        AUDIO_WAVEFORM_TRANSFORM_CLASS_NAMES,
    )


import_transforms(os.path.dirname(__file__), "waveform")


class CompositeAudioWaveformTransform(CompositeAudioTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        return super()._from_config_dict(
            cls,
            "waveform",
            get_audio_waveform_transform,
            CompositeAudioWaveformTransform,
            config,
        )

    def __call__(self, x, sample_rate):
        for t in self.transforms:
            x, sample_rate = t(x, sample_rate)
        return x, sample_rate
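
A sketch of how a new transform would plug into this registry; the "gainaugment" name and GainAugmentTransform class are hypothetical and not part of the commit.

# Hypothetical waveform transform registered against the registry above.
import numpy as np

from fairseq.data.audio.waveform_transforms import (
    AudioWaveformTransform,
    register_audio_waveform_transform,
)


@register_audio_waveform_transform("gainaugment")  # hypothetical name
class GainAugmentTransform(AudioWaveformTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return cls(_config.get("gain", 0.5))

    def __init__(self, gain: float = 0.5):
        self.gain = gain

    def __call__(self, x, sample_rate):
        # Waveform transforms return (waveform, sample_rate) so they compose
        # inside CompositeAudioWaveformTransform.
        return x * self.gain, sample_rate


scaled, sr = GainAugmentTransform.from_config_dict({"gain": 0.3})(np.ones(16000), 16000)
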
fairseq/fairseq/data/audio/waveform_transforms/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.67 kB). View file
 
fairseq/fairseq/data/audio/waveform_transforms/__pycache__/noiseaugment.cpython-310.pyc ADDED
Binary file (6.29 kB). View file
 
fairseq/fairseq/data/audio/waveform_transforms/noiseaugment.py ADDED
@@ -0,0 +1,201 @@
from pathlib import Path
import numpy as np
from math import ceil

from fairseq.data.audio import rand_uniform
from fairseq.data.audio.waveform_transforms import (
    AudioWaveformTransform,
    register_audio_waveform_transform,
)

SNR_MIN = 5.0
SNR_MAX = 15.0
RATE = 0.25

NOISE_RATE = 1.0
NOISE_LEN_MEAN = 0.2
NOISE_LEN_STD = 0.05


class NoiseAugmentTransform(AudioWaveformTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return cls(
            _config.get("samples_path", None),
            _config.get("snr_min", SNR_MIN),
            _config.get("snr_max", SNR_MAX),
            _config.get("rate", RATE),
        )

    def __init__(
        self,
        samples_path: str,
        snr_min: float = SNR_MIN,
        snr_max: float = SNR_MAX,
        rate: float = RATE,
    ):
        # Sanity checks
        assert (
            samples_path
        ), "need to provide path to audio samples for noise augmentation"
        assert snr_max >= snr_min, f"empty signal-to-noise range ({snr_min}, {snr_max})"
        assert rate >= 0 and rate <= 1, "rate should be a float between 0 to 1"

        self.paths = list(Path(samples_path).glob("**/*.wav"))  # load music
        self.n_samples = len(self.paths)
        assert self.n_samples > 0, f"no audio files found in {samples_path}"

        self.snr_min = snr_min
        self.snr_max = snr_max
        self.rate = rate

    def __repr__(self):
        return (
            self.__class__.__name__
            + "("
            + ", ".join(
                [
                    f"n_samples={self.n_samples}",
                    f"snr={self.snr_min}-{self.snr_max}dB",
                    f"rate={self.rate}",
                ]
            )
            + ")"
        )

    def pick_sample(self, goal_shape, always_2d=False, use_sample_rate=None):
        from fairseq.data.audio.audio_utils import get_waveform

        path = self.paths[np.random.randint(0, self.n_samples)]
        sample = get_waveform(
            path, always_2d=always_2d, output_sample_rate=use_sample_rate
        )[0]

        # Check dimensions match, else silently skip adding noise to sample
        # NOTE: SHOULD THIS QUIT WITH AN ERROR?
        is_2d = len(goal_shape) == 2
        if len(goal_shape) != sample.ndim or (
            is_2d and goal_shape[0] != sample.shape[0]
        ):
            return np.zeros(goal_shape)

        # Cut/repeat sample to size
        len_dim = len(goal_shape) - 1
        n_repeat = ceil(goal_shape[len_dim] / sample.shape[len_dim])
        repeated = np.tile(sample, [1, n_repeat] if is_2d else n_repeat)
        start = np.random.randint(0, repeated.shape[len_dim] - goal_shape[len_dim] + 1)
        return (
            repeated[:, start : start + goal_shape[len_dim]]
            if is_2d
            else repeated[start : start + goal_shape[len_dim]]
        )

    def _mix(self, source, noise, snr):
        get_power = lambda x: np.mean(x**2)
        if get_power(noise):
            scl = np.sqrt(
                get_power(source) / (np.power(10, snr / 10) * get_power(noise))
            )
        else:
            scl = 0
        return 1 * source + scl * noise

    def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
        return self.pick_sample(goal_shape, always_2d, use_sample_rate)

    def __call__(self, source, sample_rate):
        if np.random.random() > self.rate:
            return source, sample_rate

        noise = self._get_noise(
            source.shape, always_2d=True, use_sample_rate=sample_rate
        )

        return (
            self._mix(source, noise, rand_uniform(self.snr_min, self.snr_max)),
            sample_rate,
        )


@register_audio_waveform_transform("musicaugment")
class MusicAugmentTransform(NoiseAugmentTransform):
    pass


@register_audio_waveform_transform("backgroundnoiseaugment")
class BackgroundNoiseAugmentTransform(NoiseAugmentTransform):
    pass


@register_audio_waveform_transform("babbleaugment")
class BabbleAugmentTransform(NoiseAugmentTransform):
    def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
        for i in range(np.random.randint(3, 8)):
            speech = self.pick_sample(goal_shape, always_2d, use_sample_rate)
            if i == 0:
                agg_noise = speech
            else:  # SNR scaled by i (how many noise signals already in agg_noise)
                agg_noise = self._mix(agg_noise, speech, i)
        return agg_noise


@register_audio_waveform_transform("sporadicnoiseaugment")
class SporadicNoiseAugmentTransform(NoiseAugmentTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return cls(
            _config.get("samples_path", None),
            _config.get("snr_min", SNR_MIN),
            _config.get("snr_max", SNR_MAX),
            _config.get("rate", RATE),
            _config.get("noise_rate", NOISE_RATE),
            _config.get("noise_len_mean", NOISE_LEN_MEAN),
            _config.get("noise_len_std", NOISE_LEN_STD),
        )

    def __init__(
        self,
        samples_path: str,
        snr_min: float = SNR_MIN,
        snr_max: float = SNR_MAX,
        rate: float = RATE,
        noise_rate: float = NOISE_RATE,  # noises per second
        noise_len_mean: float = NOISE_LEN_MEAN,  # length of noises in seconds
        noise_len_std: float = NOISE_LEN_STD,
    ):
        super().__init__(samples_path, snr_min, snr_max, rate)
        self.noise_rate = noise_rate
        self.noise_len_mean = noise_len_mean
        self.noise_len_std = noise_len_std

    def _get_noise(self, goal_shape, always_2d=False, use_sample_rate=None):
        agg_noise = np.zeros(goal_shape)
        len_dim = len(goal_shape) - 1
        is_2d = len(goal_shape) == 2

        n_noises = round(self.noise_rate * goal_shape[len_dim] / use_sample_rate)
        start_pointers = [
            round(rand_uniform(0, goal_shape[len_dim])) for _ in range(n_noises)
        ]

        for start_pointer in start_pointers:
            noise_shape = list(goal_shape)
            len_seconds = np.random.normal(self.noise_len_mean, self.noise_len_std)
            noise_shape[len_dim] = round(max(0, len_seconds) * use_sample_rate)
            end_pointer = start_pointer + noise_shape[len_dim]
            if end_pointer >= goal_shape[len_dim]:
                continue

            noise = self.pick_sample(noise_shape, always_2d, use_sample_rate)
            if is_2d:
                agg_noise[:, start_pointer:end_pointer] = (
                    agg_noise[:, start_pointer:end_pointer] + noise
                )
            else:
                agg_noise[start_pointer:end_pointer] = (
                    agg_noise[start_pointer:end_pointer] + noise
                )

        return agg_noise
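
Applied directly (outside the composite machinery), the music variant above behaves as in the sketch below; the sample directory and sampling rate are placeholders, and the directory must contain at least one .wav file.

# Sketch of applying MusicAugmentTransform to a raw waveform.
import numpy as np

from fairseq.data.audio.waveform_transforms.noiseaugment import MusicAugmentTransform

transform = MusicAugmentTransform.from_config_dict(
    {"samples_path": "/path/to/music", "snr_min": 5.0, "snr_max": 15.0, "rate": 1.0}
)
waveform = np.random.randn(1, 16000)  # (channels, samples): one second at 16 kHz
augmented, sr = transform(waveform, 16000)  # rate=1.0 means noise is always mixed in
assert augmented.shape == waveform.shape and sr == 16000
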
fairseq/fairseq/data/data_utils_fast.cpp ADDED
The diff for this file is too large to render. See raw diff
 
fairseq/fairseq/data/encoders/__init__.py ADDED
@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import importlib
import os

from fairseq import registry


build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry(
    "--tokenizer",
    default=None,
)


build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry(
    "--bpe",
    default=None,
)


# automatically import any Python files in the encoders/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.data.encoders." + module)
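
Because of the auto-import loop above, any module dropped into this directory that calls one of the registration decorators becomes selectable via --tokenizer / --bpe. A hypothetical example; the "lowercase" tokenizer does not exist in fairseq.

# Hypothetical module that would live next to the files in this directory.
from fairseq.data.encoders import register_tokenizer


@register_tokenizer("lowercase")  # would be picked as `--tokenizer lowercase`
class LowercaseTokenizer(object):
    def __init__(self, *unused):
        pass

    def encode(self, x: str) -> str:
        return x.lower()

    def decode(self, x: str) -> str:
        return x
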
fairseq/fairseq/data/encoders/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (721 Bytes). View file
 
fairseq/fairseq/data/encoders/__pycache__/byte_bpe.cpython-310.pyc ADDED
Binary file (1.85 kB). View file
 
fairseq/fairseq/data/encoders/__pycache__/byte_utils.cpython-310.pyc ADDED
Binary file (2.21 kB). View file
 
fairseq/fairseq/data/encoders/__pycache__/bytes.cpython-310.pyc ADDED
Binary file (1.26 kB). View file
 
fairseq/fairseq/data/encoders/__pycache__/characters.cpython-310.pyc ADDED
Binary file (1.17 kB). View file
 
fairseq/fairseq/data/encoders/__pycache__/fastbpe.cpython-310.pyc ADDED
Binary file (1.64 kB). View file
 
fairseq/fairseq/data/encoders/__pycache__/gpt2_bpe.cpython-310.pyc ADDED
Binary file (2.15 kB). View file
 
fairseq/fairseq/data/encoders/__pycache__/hf_bert_bpe.cpython-310.pyc ADDED
Binary file (2.01 kB). View file
 
fairseq/fairseq/data/encoders/__pycache__/moses_tokenizer.cpython-310.pyc ADDED
Binary file (1.97 kB). View file
 
fairseq/fairseq/data/encoders/__pycache__/sentencepiece_bpe.cpython-310.pyc ADDED
Binary file (2.31 kB). View file
 
fairseq/fairseq/data/encoders/byte_bpe.py ADDED
@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
    SPACE,
    SPACE_ESCAPE,
    byte_encode,
    smart_byte_decode,
)
from fairseq.dataclass import FairseqDataclass


@dataclass
class ByteBpeConfig(FairseqDataclass):
    sentencepiece_model_path: str = field(
        default="???", metadata={"help": "path to sentencepiece model"}
    )


@register_bpe("byte_bpe", dataclass=ByteBpeConfig)
class ByteBPE(object):
    def __init__(self, cfg):
        vocab = file_utils.cached_path(cfg.sentencepiece_model_path)
        try:
            import sentencepiece as spm

            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(vocab)
        except ImportError:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            )

    def encode(self, x: str) -> str:
        byte_encoded = byte_encode(x)
        return SPACE.join(self.sp.EncodeAsPieces(byte_encoded))

    @staticmethod
    def decode(x: str) -> str:
        unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
        return smart_byte_decode(unescaped)
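
A round-trip usage sketch for the byte-level BPE above; the model path is a placeholder for a sentencepiece model trained on byte-encoded text.

# Usage sketch for ByteBPE (placeholder model path).
from argparse import Namespace

from fairseq.data.encoders.byte_bpe import ByteBPE

bpe = ByteBPE(Namespace(sentencepiece_model_path="/path/to/spm_byte.model"))
pieces = bpe.encode("hello world")  # byte-encode, then join sentencepiece pieces
text = ByteBPE.decode(pieces)       # strip piece separators, decode bytes
# For plain ASCII input the round trip recovers the original string.
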