jiaqili3 committed
Commit d949a00 · 1 Parent(s): addb7e5
models/tts/valle_v2.1/base_trainer.py ADDED
@@ -0,0 +1,810 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import random
9
+ import shutil
10
+ import time
11
+ from abc import abstractmethod
12
+ from pathlib import Path
13
+ import math
14
+ import accelerate
15
+ import json5
16
+ import numpy as np
17
+ import torch
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration
20
+ from torch.utils.data import ConcatDataset, DataLoader
21
+ from tqdm import tqdm
22
+
23
+ from models.base.base_sampler import build_samplers
24
+ from optimizer.optimizers import NoamLR
25
+
26
+
27
+ class MainProcessLogger:
28
+ def __init__(self, is_main_process=True, name=None, **kwargs):
29
+ import logging
30
+
31
+ if name is None:
32
+ logger = logging.getLogger(__name__)
33
+ else:
34
+ logger = logging.getLogger(name)
35
+ self.logger = logger
36
+ self.is_main_process = is_main_process
37
+
38
+ def info(self, msg):
39
+ if self.is_main_process:
40
+ print(msg)
41
+ # self.logger.info(msg)
42
+
43
+ def debug(self, msg):
44
+ if self.is_main_process:
45
+ print(msg)
46
+ # self.logger.debug(msg)
47
+
48
+ def warning(self, msg):
49
+ if self.is_main_process:
50
+ print(msg)
51
+ # self.logger.warning(msg)
52
+
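A small sketch of how this print-based wrapper behaves (illustrative only; inside the trainer the flag comes from `accelerator.is_main_process`, here it is passed by hand):

# Illustrative usage of MainProcessLogger, outside of any Accelerate context.
worker_logger = MainProcessLogger(is_main_process=False, name="demo")
worker_logger.info("suppressed on non-main processes")   # prints nothing
main_logger = MainProcessLogger(is_main_process=True, name="demo")
main_logger.info("printed once, from the main process")  # prints the message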
53
+
54
+ class BaseTrainer(object):
55
+ r"""The base trainer for all tasks. Any trainer should inherit from this class."""
56
+
57
+ def __init__(self, args=None, cfg=None):
58
+ super().__init__()
59
+
60
+ self.args = args
61
+ self.cfg = cfg
62
+
63
+ cfg.exp_name = args.exp_name
64
+
65
+ # init with accelerate
66
+ self._init_accelerator()
67
+ self.accelerator.wait_for_everyone()
68
+
69
+ # Use accelerate logger for distributed training
70
+ with self.accelerator.main_process_first():
71
+ self.logger = MainProcessLogger(
72
+ self.accelerator.is_main_process,
73
+ name=args.exp_name,
74
+ log_level=args.log_level,
75
+ )
76
+
77
+ # Log some info
78
+ self.logger.info("=" * 56)
79
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
80
+ self.logger.info("=" * 56)
81
+ self.logger.info("\n")
82
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
83
+ self.logger.info(f"Experiment name: {args.exp_name}")
84
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
85
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
86
+ if self.accelerator.is_main_process:
87
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
88
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
89
+
90
+ # init counts
91
+ self.batch_count: int = 0
92
+ self.step: int = 0
93
+ self.epoch: int = 0
94
+ self.max_epoch = (
95
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
96
+ )
97
+ self.logger.info(
98
+ "Max epoch: {}".format(
99
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
100
+ )
101
+ )
102
+
103
+ # Check values
104
+ if self.accelerator.is_main_process:
105
+ self.__check_basic_configs()
106
+ # Set runtime configs
107
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
108
+ self.checkpoints_path = [
109
+ [] for _ in range(len(self.save_checkpoint_stride))
110
+ ]
111
+ self.keep_last = [
112
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
113
+ ]
114
+ self.run_eval = self.cfg.train.run_eval
115
+
116
+ # set random seed
117
+ with self.accelerator.main_process_first():
118
+ start = time.monotonic_ns()
119
+ self._set_random_seed(args.seed)
120
+ end = time.monotonic_ns()
121
+ self.logger.debug(
122
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
123
+ )
124
+ self.logger.debug(f"Random seed: {args.seed}")
125
+
126
+ # setup data_loader
127
+ with self.accelerator.main_process_first():
128
+ self.logger.info("Building dataset...")
129
+ start = time.monotonic_ns()
130
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
131
+ end = time.monotonic_ns()
132
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
133
+
134
+ # setup model
135
+ with self.accelerator.main_process_first():
136
+ self.logger.info("Building model...")
137
+ start = time.monotonic_ns()
138
+ self.model = self._build_model()
139
+ end = time.monotonic_ns()
140
+ self.logger.debug(self.model)
141
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
142
+ self.logger.info(
143
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
144
+ )
145
+ # optimizer & scheduler
146
+ with self.accelerator.main_process_first():
147
+ self.logger.info("Building optimizer and scheduler...")
148
+ start = time.monotonic_ns()
149
+ self.optimizer = self._build_optimizer()
150
+ self.scheduler = self._build_scheduler()
151
+ end = time.monotonic_ns()
152
+ self.logger.info(
153
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
154
+ )
155
+
156
+ # accelerate prepare
157
+ self.logger.info("Initializing accelerate...")
158
+ start = time.monotonic_ns()
159
+ self._accelerator_prepare()
160
+ end = time.monotonic_ns()
161
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
162
+
163
+ # create criterion
164
+ with self.accelerator.main_process_first():
165
+ self.logger.info("Building criterion...")
166
+ start = time.monotonic_ns()
167
+ self.criterion = self._build_criterion()
168
+ end = time.monotonic_ns()
169
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
170
+
171
+ # Resume or Finetune
172
+ with self.accelerator.main_process_first():
173
+ if args.resume:
174
+ if args.resume_from_ckpt_path == "":
175
+ ## Automatically resume according to the current experiment name
176
+ self.logger.info(
177
+ "Automatically resuming from latest checkpoint in {}...".format(
178
+ self.checkpoint_dir
179
+ )
180
+ )
181
+ start = time.monotonic_ns()
182
+ ckpt_path = self._load_model(
183
+ checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
184
+ )
185
+ end = time.monotonic_ns()
186
+ self.logger.info(
187
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
188
+ )
189
+ else:
190
+ ## Resume from the given checkpoint path
191
+ if not os.path.exists(args.resume_from_ckpt_path):
192
+ raise ValueError(
193
+ "[Error] The resumed checkpoint path {} does not exist.".format(
194
+ args.resume_from_ckpt_path
195
+ )
196
+ )
197
+ self.logger.info(
198
+ "Resuming from {}...".format(args.resume_from_ckpt_path)
199
+ )
200
+ start = time.monotonic_ns()
201
+ ckpt_path = self._load_model(
202
+ checkpoint_path=args.resume_from_ckpt_path,
203
+ resume_type=args.resume_type,
204
+ )
205
+ end = time.monotonic_ns()
206
+ self.logger.info(
207
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
208
+ )
209
+
210
+ # save config file path
211
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
212
+
213
+ def _accelerator_prepare(self):
214
+ (
215
+ self.train_dataloader,
216
+ self.valid_dataloader,
217
+ self.model,
218
+ self.optimizer,
219
+ self.scheduler,
220
+ ) = self.accelerator.prepare(
221
+ self.train_dataloader,
222
+ self.valid_dataloader,
223
+ self.model,
224
+ self.optimizer,
225
+ self.scheduler,
226
+ )
227
+
228
+ ### Following are abstract methods that should be implemented in child classes ###
229
+ @abstractmethod
230
+ def _build_dataset(self):
231
+ r"""Build dataset for model training/validating/evaluating."""
232
+ pass
233
+
234
+ @staticmethod
235
+ @abstractmethod
236
+ def _build_criterion():
237
+ r"""Build criterion function for model loss calculation."""
238
+ pass
239
+
240
+ @abstractmethod
241
+ def _build_model(self):
242
+ r"""Build model for training/validating/evaluating."""
243
+ pass
244
+
245
+ @abstractmethod
246
+ def _forward_step(self, batch):
247
+ r"""One forward step of the neural network. This abstract method aims to
+ unify ``_train_step`` and ``_valid_step`` and avoid redundant implementations.
+ However, for special cases that need different forward-step patterns for
+ training and validation, you can override this method with ``pass`` and
+ implement ``_train_step`` and ``_valid_step`` separately.
252
+ """
253
+ pass
254
+
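The abstract methods above are the whole contract a concrete trainer has to fill in. A minimal sketch of a subclass is shown below; the `Toy*` names and the tiny model are hypothetical placeholders, not part of this commit, and instantiating the class still needs the usual `args`/`cfg` objects.

# Minimal illustrative subclass (hypothetical names, assumption: BaseTrainer as defined above).
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Placeholder dataset: 32 random (x, y) pairs."""
    def __init__(self, cfg, dataset_name, is_valid=False):
        self.x = torch.randn(32, 8)
        self.y = torch.randn(32, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return {"x": self.x[idx], "y": self.y[idx]}


class ToyCollator:
    """Placeholder collator: stacks a list of dicts into batched tensors."""
    def __init__(self, cfg):
        pass

    def __call__(self, items):
        return {k: torch.stack([item[k] for item in items]) for k in items[0]}


class ToyTrainer(BaseTrainer):
    def _build_dataset(self):
        # _build_dataloader expects a (DatasetClass, CollatorClass) pair.
        return ToyDataset, ToyCollator

    def _build_criterion(self):
        return F.mse_loss

    def _build_model(self):
        return torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))

    def _forward_step(self, batch):
        # Must return a scalar loss; used by both _train_step and _valid_step.
        pred = self.model(batch["x"])
        return self.criterion(pred, batch["y"])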
255
+ def save_checkpoint(self):
256
+ if self.accelerator.is_main_process:
257
+ keep_last = self.keep_last[0]
258
+ # List all checkpoint folders under self.checkpoint_dir
259
+ all_ckpts = os.listdir(self.checkpoint_dir)
260
+ all_ckpts = filter(lambda x: x.startswith("epoch"), all_ckpts)
261
+ all_ckpts = list(all_ckpts)
262
+ if len(all_ckpts) > keep_last:
263
+ # Keep only the newest `keep_last` folders in self.checkpoint_dir, sorted by step ("epoch-{:04d}_step-{:07d}_loss-{:.6f}")
264
+ all_ckpts = sorted(
265
+ all_ckpts, key=lambda x: int(x.split("_")[1].split("-")[1])
266
+ )
267
+ for ckpt in all_ckpts[:-keep_last]:
268
+ shutil.rmtree(os.path.join(self.checkpoint_dir, ckpt))
269
+ checkpoint_filename = "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
270
+ self.epoch, self.step, self.current_loss
271
+ )
272
+ path = os.path.join(self.checkpoint_dir, checkpoint_filename)
273
+ self.logger.info("Saving state to {}...".format(path))
274
+ self.accelerator.save_state(path)
275
+ self.logger.info("Finished saving state.")
276
+
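save_checkpoint prunes old folders by parsing the step back out of the `epoch-{:04d}_step-{:07d}_loss-{:.6f}` name, and `_load_model` below relies on the same convention when resuming. A small self-contained check of that round trip (the numbers are made up for illustration):

# Round-trip check of the checkpoint naming convention used above.
name = "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(3, 12500, 1.234567)
print(name)                                      # epoch-0003_step-0012500_loss-1.234567
step = int(name.split("_")[1].split("-")[1])     # sort key used in save_checkpoint
epoch = int(name.split("_")[-3].split("-")[-1])  # parsed again when resuming
print(step, epoch)                               # 12500 3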
277
+ @abstractmethod
278
+ def _save_auxiliary_states(self):
279
+ r"""To save some auxiliary states when saving model's ckpt"""
280
+ pass
281
+
282
+ def echo_log(self, losses, mode="Training"):
283
+ message = [
284
+ "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
285
+ mode, self.epoch + 1, self.step, self.time_window.average
286
+ )
287
+ ]
288
+
289
+ for key in sorted(losses.keys()):
290
+ if isinstance(losses[key], dict):
291
+ for k, v in losses[key].items():
292
+ message.append(
293
+ str(k).split("/")[-1] + "=" + str(round(float(v), 5))
294
+ )
295
+ else:
296
+ message.append(
297
+ str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
298
+ )
299
+ self.logger.info(", ".join(message))
300
+
301
+ ### Abstract methods end ###
302
+
303
+ ### THIS IS MAIN ENTRY ###
304
+ def train_loop(self):
305
+ r"""Training loop. The public entry of training process."""
306
+ # Wait everyone to prepare before we move on
307
+ self.accelerator.wait_for_everyone()
308
+ # dump config file
309
+ if self.accelerator.is_main_process:
310
+ self.__dump_cfg(self.config_save_path)
311
+ self.model.train()
312
+ self.optimizer.zero_grad()
313
+ while self.epoch < self.max_epoch:
314
+ self.logger.info("\n")
315
+ self.logger.info("-" * 32)
316
+ self.logger.info("Epoch {}: ".format(self.epoch))
317
+
318
+ ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
319
+ ### It's inconvenient for the model with multiple losses
320
+ # Do training & validating epoch
321
+ train_loss = self._train_epoch()
322
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
323
+ valid_loss = self._valid_epoch()
324
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
325
+ self.accelerator.log(
326
+ {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
327
+ step=self.epoch,
328
+ )
329
+
330
+ self.accelerator.wait_for_everyone()
331
+
332
+ # Update info for each epoch
333
+ self.epoch += 1
334
+
335
+ # Finish training and save final checkpoint
336
+ self.accelerator.wait_for_everyone()
337
+ if self.accelerator.is_main_process:
338
+ self.accelerator.save_state(
339
+ os.path.join(
340
+ self.checkpoint_dir,
341
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
342
+ self.epoch, self.step, valid_loss
343
+ ),
344
+ )
345
+ )
346
+ self._save_auxiliary_states()
347
+
348
+ self.accelerator.end_training()
349
+
350
+ def get_lr(self, it):
351
+ # 1) linear warmup for warmup_iters steps
352
+ if it < self.cfg.train.scheduler.warmup_steps:
353
+ return self.cfg.train.adamw.lr * it / self.cfg.train.scheduler.warmup_steps
354
+ # 2) if it > lr_decay_iters, return min learning rate
355
+ if it > self.cfg.train.scheduler.total_steps:
356
+ return self.cfg.train.scheduler.min_lr
357
+ # 3) in between, use cosine decay down to min learning rate
358
+ decay_ratio = (it - self.cfg.train.scheduler.warmup_steps) / (
359
+ self.cfg.train.scheduler.total_steps - self.cfg.train.scheduler.warmup_steps
360
+ )
361
+ assert 0 <= decay_ratio <= 1
362
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
363
+ return self.cfg.train.scheduler.min_lr + coeff * (
364
+ self.cfg.train.adamw.lr - self.cfg.train.scheduler.min_lr
365
+ )
366
+
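For intuition, the warmup-plus-cosine schedule in `get_lr` can be reproduced standalone. The hyperparameters below (peak 2e-4, min 5e-5, 8k warmup, 400k total steps) mirror the values in `cfg/base.yaml` from this commit; treat them as an example, not a requirement.

import math

def cosine_lr(it, peak_lr=2e-4, min_lr=5e-5, warmup_steps=8000, total_steps=400000):
    """Standalone copy of the schedule implemented by BaseTrainer.get_lr."""
    if it < warmup_steps:                   # 1) linear warmup
        return peak_lr * it / warmup_steps
    if it > total_steps:                    # 2) floor after total_steps
        return min_lr
    ratio = (it - warmup_steps) / (total_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))  # 3) cosine decay, 1 -> 0
    return min_lr + coeff * (peak_lr - min_lr)

for step in (0, 4000, 8000, 204000, 400000, 500000):
    print(step, f"{cosine_lr(step):.2e}")
# 0 -> 0, 4000 -> 1e-4, 8000 -> 2e-4, midpoint -> 1.25e-4, 400000 and beyond -> 5e-5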
367
+ ### Following are methods that can be used directly in child classes ###
368
+ def _train_epoch(self):
369
+ r"""Training epoch. Should return average loss of a batch (sample) over
370
+ one epoch. See ``train_loop`` for usage.
371
+ """
372
+ self.model.train()
373
+ epoch_sum_loss: float = 0.0
374
+ ema_loss = None
375
+
376
+ # profiler
377
+ start_this_step_time = time.time()
378
+ finish_last_step_time = time.time()
379
+
380
+ for batch in tqdm(
381
+ self.train_dataloader,
382
+ desc=f"Training Epoch {self.epoch}",
383
+ unit="batch",
384
+ colour="GREEN",
385
+ leave=False,
386
+ dynamic_ncols=True,
387
+ smoothing=0.04,
388
+ disable=not self.accelerator.is_main_process,
389
+ ):
390
+ assert batch is not None
391
+
392
+ # start_this_step_time = time.time()
393
+ # print(f'load batch took: {start_this_step_time - finish_last_step_time:.6f}s')
394
+
395
+ # update learning rate
396
+ lr = self.get_lr(self.step)
397
+ for param_group in self.optimizer.param_groups:
398
+ param_group["lr"] = lr
399
+
400
+ # Do training step and BP
401
+ with self.accelerator.accumulate(self.model):
402
+ loss = self._train_step(batch)
403
+ self.current_loss = loss.item()
404
+ ema_loss = (
405
+ 0.99 * ema_loss + 0.01 * self.current_loss
406
+ if ema_loss is not None
407
+ else self.current_loss
408
+ )
409
+ self.accelerator.backward(loss)
410
+ if self.accelerator.sync_gradients:
411
+ self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
412
+ self.optimizer.step()
413
+ self.optimizer.zero_grad()
414
+ self.batch_count += 1
415
+
416
+ # if self.accelerator.is_main_process:
417
+ # print(self.current_loss)
418
+
419
+ if self.accelerator.sync_gradients:
420
+ if self.step % self.cfg.train.save_checkpoint_stride[0] == 0:
421
+ self.accelerator.wait_for_everyone()
422
+ if self.accelerator.is_main_process:
423
+ try:
424
+ self.save_checkpoint()
425
+ except Exception as e:
+ self.logger.info(f"Failed to save checkpoint ({e}), resuming training...")
427
+ if self.accelerator.is_main_process:
428
+ if self.step % 100 == 0:
429
+ self.logger.info(f"EMA Loss: {ema_loss:.6f}")
430
+ self.accelerator.log(
431
+ {
432
+ "Step/Train Loss": loss,
433
+ "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
434
+ },
435
+ step=self.step,
436
+ )
437
+ epoch_sum_loss += loss.item()
438
+ self.step += 1
439
+
440
+ # finish_last_step_time = time.time()
441
+ # print(f'load took: {finish_last_step_time - start_this_step_time:.6f}s')
442
+ return (
443
+ epoch_sum_loss
444
+ / len(self.train_dataloader)
445
+ * self.cfg.train.gradient_accumulation_step
446
+ )
447
+
448
+ @torch.inference_mode()
449
+ def _valid_epoch(self):
450
+ r"""Validation epoch. Should return the average loss of a batch (sample) over
451
+ one epoch. See ``train_loop`` for usage.
452
+ """
453
+ self.model.eval()
454
+ epoch_sum_loss = 0.0
455
+ for batch in tqdm(
456
+ self.valid_dataloader,
457
+ desc=f"Validating Epoch {self.epoch}",
458
+ unit="batch",
459
+ colour="GREEN",
460
+ leave=False,
461
+ dynamic_ncols=True,
462
+ smoothing=0.04,
463
+ disable=not self.accelerator.is_main_process,
464
+ ):
465
+ batch_loss = self._valid_step(batch)
466
+ epoch_sum_loss += batch_loss.item()
467
+
468
+ return epoch_sum_loss / len(self.valid_dataloader)
469
+
470
+ def _train_step(self, batch):
471
+ r"""Training forward step. Should return average loss of a sample over
472
+ one batch. Calling ``_forward_step`` is recommended except in special cases.
473
+ See ``_train_epoch`` for usage.
474
+ """
475
+ return self._forward_step(batch)
476
+
477
+ @torch.inference_mode()
478
+ def _valid_step(self, batch):
479
+ r"""Validation forward step. Should return the average loss of a sample over
+ one batch. Calling ``_forward_step`` is recommended except in special cases.
+ See ``_valid_epoch`` for usage.
482
+ """
483
+ return self._forward_step(batch)
484
+
485
+ def _load_model(
486
+ self,
487
+ checkpoint_dir: str = None,
488
+ checkpoint_path: str = None,
489
+ resume_type: str = "",
490
+ ):
491
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
492
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
493
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
494
+ method after** ``accelerator.prepare()``.
495
+ """
496
+ if checkpoint_path is None:
497
+ try:
498
+ all_ckpts = os.listdir(checkpoint_dir)
499
+ all_ckpts = filter(lambda x: x.startswith("epoch"), all_ckpts)
500
+ ls = list(all_ckpts)
501
+ ls = [os.path.join(checkpoint_dir, i) for i in ls]
502
+ ls.sort(
503
+ key=lambda x: int(x.split("_")[-2].split("-")[-1]), reverse=True
504
+ )
505
+ checkpoint_path = ls[0]
506
+ self.logger.info("Resume from {}".format(checkpoint_path))
507
+ except Exception as e:
508
+ print(
509
+ "Failed to load checkpoint from {}, starting FROM SCRATCH...".format(
510
+ checkpoint_dir
511
+ )
512
+ )
513
+ return None
514
+
515
+ if resume_type in ["resume", ""]:
516
+ # Load all the things, including model weights, optimizer, scheduler, and random states.
517
+ self.accelerator.load_state(input_dir=checkpoint_path)
518
+
519
+ # set epoch and step
520
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
521
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
522
+
523
+ elif resume_type == "finetune":
524
+ # Load only the model weights
525
+ accelerate.load_checkpoint_and_dispatch(
526
+ self.accelerator.unwrap_model(self.model),
527
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
528
+ )
529
+ self.logger.info("Load model weights for finetune...")
530
+
531
+ else:
532
+ raise ValueError("Resume_type must be `resume` or `finetune`.")
533
+
534
+ return checkpoint_path
535
+
536
+ # TODO: LEGACY CODE
537
+ def _build_dataloader(self):
538
+ Dataset, Collator = self._build_dataset()
539
+
540
+ # build dataset instance for each dataset and combine them by ConcatDataset
541
+ datasets_list = []
542
+ for dataset in self.cfg.dataset:
543
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
544
+ datasets_list.append(subdataset)
545
+ train_dataset = ConcatDataset(datasets_list)
546
+ train_collate = Collator(self.cfg)
547
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
548
+ self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
549
+ self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
550
+ # TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
551
+ train_loader = DataLoader(
552
+ train_dataset,
553
+ collate_fn=train_collate,
554
+ batch_sampler=batch_sampler,
555
+ num_workers=self.cfg.train.dataloader.num_worker,
556
+ pin_memory=self.cfg.train.dataloader.pin_memory,
557
+ )
558
+
559
+ # Build valid dataloader
560
+ datasets_list = []
561
+ for dataset in self.cfg.dataset:
562
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
563
+ datasets_list.append(subdataset)
564
+ valid_dataset = ConcatDataset(datasets_list)
565
+ valid_collate = Collator(self.cfg)
566
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
567
+ self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
568
+ self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
569
+ valid_loader = DataLoader(
570
+ valid_dataset,
571
+ collate_fn=valid_collate,
572
+ batch_sampler=batch_sampler,
573
+ num_workers=self.cfg.train.dataloader.num_worker,
574
+ pin_memory=self.cfg.train.dataloader.pin_memory,
575
+ )
576
+ return train_loader, valid_loader
577
+
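The loaders above concatenate one sub-dataset per entry in `cfg.dataset` and log `cumulative_sizes`, which marks where each sub-dataset ends inside the combined index space. A tiny standalone illustration:

import torch
from torch.utils.data import ConcatDataset, TensorDataset

part_a = TensorDataset(torch.zeros(3, 1))   # 3 items
part_b = TensorDataset(torch.ones(5, 1))    # 5 items
combined = ConcatDataset([part_a, part_b])

print(len(combined))               # 8
print(combined.cumulative_sizes)   # [3, 8] -- the value logged as "length" above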
578
+ @staticmethod
579
+ def _set_random_seed(seed):
580
+ r"""Set random seed for all possible random modules."""
581
+ random.seed(seed)
582
+ np.random.seed(seed)
583
+ torch.random.manual_seed(seed)
584
+
585
+ def _check_nan(self, loss, y_pred, y_gt):
586
+ if torch.any(torch.isnan(loss)):
587
+ self.logger.fatal("Fatal Error: Training stopped because the loss is NaN!")
588
+ self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
589
+ if torch.any(torch.isnan(y_pred)):
590
+ self.logger.error(
591
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
592
+ )
593
+ else:
594
+ self.logger.debug(
595
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
596
+ )
597
+ if torch.any(torch.isnan(y_gt)):
598
+ self.logger.error(
599
+ f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
600
+ )
601
+ else:
602
+ self.logger.debug(
603
+ f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
604
+ )
605
+ if torch.any(torch.isnan(y_pred)):
606
+ self.logger.error(f"y_pred: {y_pred}", in_order=True)
607
+ else:
608
+ self.logger.debug(f"y_pred: {y_pred}", in_order=True)
609
+ if torch.any(torch.isnan(y_gt)):
610
+ self.logger.error(f"y_gt: {y_gt}", in_order=True)
611
+ else:
612
+ self.logger.debug(f"y_gt: {y_gt}", in_order=True)
613
+
614
+ # TODO: still OK to save tracking?
615
+ self.accelerator.end_training()
616
+ raise RuntimeError("Loss is NaN! See the log for more info.")
617
+
618
+ ### Protected methods end ###
619
+
620
+ ## Following are private methods ##
621
+ ## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed.
622
+ def _build_optimizer(self):
623
+ r"""Build optimizer for model."""
624
+ # Make case-insensitive matching
625
+ if self.cfg.train.optimizer.lower() == "adadelta":
626
+ optimizer = torch.optim.Adadelta(
627
+ self.model.parameters(), **self.cfg.train.adadelta
628
+ )
629
+ self.logger.info("Using Adadelta optimizer.")
630
+ elif self.cfg.train.optimizer.lower() == "adagrad":
631
+ optimizer = torch.optim.Adagrad(
632
+ self.model.parameters(), **self.cfg.train.adagrad
633
+ )
634
+ self.logger.info("Using Adagrad optimizer.")
635
+ elif self.cfg.train.optimizer.lower() == "adam":
636
+ optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
637
+ self.logger.info("Using Adam optimizer.")
638
+ elif self.cfg.train.optimizer.lower() == "adamw":
639
+ optimizer = torch.optim.AdamW(
640
+ self.model.parameters(), **self.cfg.train.adamw
641
+ )
642
+ elif self.cfg.train.optimizer.lower() == "sparseadam":
643
+ optimizer = torch.optim.SparseAdam(
644
+ self.model.parameters(), **self.cfg.train.sparseadam
645
+ )
646
+ elif self.cfg.train.optimizer.lower() == "adamax":
647
+ optimizer = torch.optim.Adamax(
648
+ self.model.parameters(), **self.cfg.train.adamax
649
+ )
650
+ elif self.cfg.train.optimizer.lower() == "asgd":
651
+ optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
652
+ elif self.cfg.train.optimizer.lower() == "lbfgs":
653
+ optimizer = torch.optim.LBFGS(
654
+ self.model.parameters(), **self.cfg.train.lbfgs
655
+ )
656
+ elif self.cfg.train.optimizer.lower() == "nadam":
657
+ optimizer = torch.optim.NAdam(
658
+ self.model.parameters(), **self.cfg.train.nadam
659
+ )
660
+ elif self.cfg.train.optimizer.lower() == "radam":
661
+ optimizer = torch.optim.RAdam(
662
+ self.model.parameters(), **self.cfg.train.radam
663
+ )
664
+ elif self.cfg.train.optimizer.lower() == "rmsprop":
665
+ optimizer = torch.optim.RMSprop(
666
+ self.model.parameters(), **self.cfg.train.rmsprop
667
+ )
668
+ elif self.cfg.train.optimizer.lower() == "rprop":
669
+ optimizer = torch.optim.Rprop(
670
+ self.model.parameters(), **self.cfg.train.rprop
671
+ )
672
+ elif self.cfg.train.optimizer.lower() == "sgd":
673
+ optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
674
+ else:
675
+ raise NotImplementedError(
676
+ f"Optimizer {self.cfg.train.optimizer} not supported yet!"
677
+ )
678
+ return optimizer
679
+
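`_build_optimizer` dispatches on the `cfg.train.optimizer` string and forwards the matching sub-dict as keyword arguments, so with the `cfg/base.yaml` in this commit (`optimizer: adamW`, `adamw: {lr: 2e-4}`) the call reduces to roughly the sketch below. Note that during training `get_lr` overwrites the learning rate every step anyway.

import torch

model = torch.nn.Linear(8, 1)      # stand-in for self.model
adamw_cfg = {"lr": 2e-4}           # cfg.train.adamw from cfg/base.yaml
optimizer = torch.optim.AdamW(model.parameters(), **adamw_cfg)
print(optimizer.defaults["lr"])    # 0.0002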
680
+ def _build_scheduler(self):
681
+ r"""Build scheduler for optimizer."""
682
+ # Make case-insensitive matching
683
+ if self.cfg.train.scheduler.lower() == "lambdalr":
684
+ scheduler = torch.optim.lr_scheduler.LambdaLR(
685
+ self.optimizer, **self.cfg.train.lambdalr
686
+ )
687
+ elif self.cfg.train.scheduler.lower() == "multiplicativelr":
688
+ scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
689
+ self.optimizer, **self.cfg.train.multiplicativelr
690
+ )
691
+ elif self.cfg.train.scheduler.lower() == "steplr":
692
+ scheduler = torch.optim.lr_scheduler.StepLR(
693
+ self.optimizer, **self.cfg.train.steplr
694
+ )
695
+ elif self.cfg.train.scheduler.lower() == "multisteplr":
696
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(
697
+ self.optimizer, **self.cfg.train.multisteplr
698
+ )
699
+ elif self.cfg.train.scheduler.lower() == "constantlr":
700
+ scheduler = torch.optim.lr_scheduler.ConstantLR(
701
+ self.optimizer, **self.cfg.train.constantlr
702
+ )
703
+ elif self.cfg.train.scheduler.lower() == "linearlr":
704
+ scheduler = torch.optim.lr_scheduler.LinearLR(
705
+ self.optimizer, **self.cfg.train.linearlr
706
+ )
707
+ elif self.cfg.train.scheduler.lower() == "exponentiallr":
708
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
709
+ self.optimizer, **self.cfg.train.exponentiallr
710
+ )
711
+ elif self.cfg.train.scheduler.lower() == "polynomiallr":
712
+ scheduler = torch.optim.lr_scheduler.PolynomialLR(
713
+ self.optimizer, **self.cfg.train.polynomiallr
714
+ )
715
+ elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
716
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
717
+ self.optimizer, **self.cfg.train.cosineannealinglr
718
+ )
719
+ elif self.cfg.train.scheduler.lower() == "sequentiallr":
720
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
721
+ self.optimizer, **self.cfg.train.sequentiallr
722
+ )
723
+ elif self.cfg.train.scheduler.lower() == "reducelronplateau":
724
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
725
+ self.optimizer, **self.cfg.train.reducelronplateau
726
+ )
727
+ elif self.cfg.train.scheduler.lower() == "cycliclr":
728
+ scheduler = torch.optim.lr_scheduler.CyclicLR(
729
+ self.optimizer, **self.cfg.train.cycliclr
730
+ )
731
+ elif self.cfg.train.scheduler.lower() == "onecyclelr":
732
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
733
+ self.optimizer, **self.cfg.train.onecyclelr
734
+ )
735
+ elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
736
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
737
+ self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
738
+ )
739
+ elif self.cfg.train.scheduler.lower() == "noamlr":
740
+ scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
741
+ else:
742
+ raise NotImplementedError(
743
+ f"Scheduler {self.cfg.train.scheduler} not supported yet!"
744
+ )
745
+ return scheduler
746
+
747
+ def _init_accelerator(self):
748
+ self.exp_dir = os.path.join(
749
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
750
+ )
751
+ project_config = ProjectConfiguration(
752
+ project_dir=self.exp_dir,
753
+ logging_dir=os.path.join(self.exp_dir, "log"),
754
+ )
755
+ from accelerate import DistributedDataParallelKwargs
756
+
757
+ kwargs = DistributedDataParallelKwargs(
758
+ find_unused_parameters=self.cfg.train.find_unused_parameters
759
+ )
760
+
761
+ self.accelerator = accelerate.Accelerator(
762
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
763
+ log_with=self.cfg.train.tracker,
764
+ project_config=project_config,
765
+ kwargs_handlers=[kwargs],
766
+ )
767
+ if self.accelerator.is_main_process:
768
+ os.makedirs(project_config.project_dir, exist_ok=True)
769
+ os.makedirs(project_config.logging_dir, exist_ok=True)
770
+ with self.accelerator.main_process_first():
771
+ self.accelerator.init_trackers(self.args.exp_name)
772
+
773
+ def __check_basic_configs(self):
774
+ if self.cfg.train.gradient_accumulation_step <= 0:
775
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
776
+ self.logger.error(
777
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
778
+ )
779
+ self.accelerator.end_training()
780
+ raise ValueError(
781
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
782
+ )
783
+ # TODO: check other values
784
+
785
+ @staticmethod
786
+ def __count_parameters(model):
787
+ model_param = 0.0
788
+ if isinstance(model, dict):
789
+ for key, value in model.items():
790
+ model_param += sum(p.numel() for p in model[key].parameters())
791
+ else:
792
+ model_param = sum(p.numel() for p in model.parameters())
793
+ return model_param
794
+
795
+ def __dump_cfg(self, path):
796
+ os.makedirs(os.path.dirname(path), exist_ok=True)
797
+ json5.dump(
798
+ self.cfg,
799
+ open(path, "w"),
800
+ indent=4,
801
+ sort_keys=True,
802
+ ensure_ascii=False,
803
+ quote_keys=True,
804
+ )
805
+
806
+ @torch.inference_mode()
807
+ def test_loop(self):
808
+ pass
809
+
810
+ ### Private methods end ###
models/tts/valle_v2.1/cfg/base.yaml ADDED
@@ -0,0 +1,136 @@
1
+ dataset_folder: /gluster-tts/emilia
2
+ w2v_path: /gluster-tts/jiaqi_repos/w2v-bert-2
3
+ ckpt_root_path: /gluster-tts/jiaqi_repos/soundstorm_ckpts
4
+ log_dir: /gluster-tts/jiaqi_repos/tmp_checkpoints
5
+ max_tokens: 9000
6
+
7
+ # fixed params cosyvoice
8
+ # sample_rate: 22050
9
+ text_encoder_input_size: 512
10
+ llm_input_size: 1536
11
+ llm_output_size: 1536
12
+ spk_embed_dim: 192
13
+
14
+ transformer_model:
15
+ _target_: models.tts.bpe_text2semantic.llm.TransformerLM
16
+ text_encoder_input_size: ${text_encoder_input_size}
17
+ llm_input_size: ${llm_input_size}
18
+ llm_output_size: ${llm_output_size}
19
+ text_token_size: 51866
20
+ speech_token_size: 8192
21
+ length_normalized_loss: true
22
+ lsm_weight: 0
23
+ spk_embed_dim: ${spk_embed_dim}
24
+ text_encoder:
25
+ _target_: cosyvoice.transformer.encoder.ConformerEncoder
26
+ input_size: ${text_encoder_input_size}
27
+ output_size: 1024
28
+ attention_heads: 16
29
+ linear_units: 4096
30
+ num_blocks: 6
31
+ dropout_rate: 0.1
32
+ positional_dropout_rate: 0.1
33
+ attention_dropout_rate: 0
34
+ normalize_before: true
35
+ input_layer: 'linear'
36
+ pos_enc_layer_type: 'rel_pos_espnet'
37
+ selfattention_layer_type: 'rel_selfattn'
38
+ use_cnn_module: false
39
+ macaron_style: false
40
+ use_dynamic_chunk: false
41
+ use_dynamic_left_chunk: false
42
+ static_chunk_size: 1
43
+ llm:
44
+ _target_: cosyvoice.transformer.encoder.TransformerEncoder
45
+ input_size: ${llm_input_size}
46
+ output_size: ${llm_output_size}
47
+ attention_heads: 16
48
+ linear_units: 4096
49
+ num_blocks: 12
50
+ dropout_rate: 0.1
51
+ positional_dropout_rate: 0.1
52
+ attention_dropout_rate: 0
53
+ input_layer: 'linear_legacy'
54
+ pos_enc_layer_type: 'rel_pos_espnet'
55
+ selfattention_layer_type: 'rel_selfattn'
56
+ static_chunk_size: 1
57
+
58
+
59
+ args:
60
+ exp_name: text2semantic
61
+ log_level: DEBUG
62
+ seed: 22
63
+ resume: false
64
+ resume_type: resume
65
+ resume_from_ckpt_path: ""
66
+ preprocess_cfg:
67
+ w2v_path: ${w2v_path}
68
+ preprocess:
69
+ sample_rate: 16000
70
+ min_dur: 3
71
+ max_dur: 30
72
+ hop_size: 320
73
+ cfg:
74
+ w2v_path: ${w2v_path}
75
+ log_dir: ${log_dir}
76
+ dataset:
77
+ _target_: models.tts.text2semantic.emilia_dataset.T2SDataset
78
+ cache_folder: ${dataset_folder}/
79
+ cfg: ${preprocess_cfg}
80
+ mnt_path: ${dataset_folder}/output_gzips/
81
+ # collator:
82
+ # _target_: models.tts.text2semantic.emilia_dataset.T2SCollator
83
+ collator:
84
+ _target_: models.tts.bpe_text2semantic.collator.T2SCollatorDynamic
85
+ max_tokens: 13000
86
+ tokenizer:
87
+ _target_: whisper.tokenizer.get_tokenizer
88
+ multilingual: True
89
+ num_languages: 100
90
+ language: 'en'
91
+ task: 'transcribe'
92
+ train:
93
+ gradient_accumulation_step: 1
94
+ find_unused_parameters: true
95
+ tracker: tensorboard
96
+ max_epoch: 1000
97
+ save_checkpoint_stride:
98
+ - 2000
99
+ keep_last: [1]
100
+ run_eval: true
101
+ dataloader:
102
+ num_worker: 0
103
+ pin_memory: false
104
+ persistent_workers: false
105
+ use_dynamic_batchsize: true
106
+ optimizer: adamW
107
+ adamw:
108
+ lr: 2e-4
109
+ scheduler:
110
+ warmup_steps: 8000
111
+ total_steps: 400000
112
+ min_lr: 5e-5
113
+ exponentiallr:
114
+ gamma: 0.999999
115
+ batch_size: 10
116
+ max_tokens: ${max_tokens}
117
+ max_sentences: 64
118
+ model: ${transformer_model}
119
+ kmeans:
120
+ type: repcodec
121
+ w2v_path: ${w2v_path}
122
+ stat_mean_var_path: ${ckpt_root_path}/semantic_kmeans/emilia_wav2vec2bert_stats_10k.pt
123
+ repcodec:
124
+ codebook_size: 8192
125
+ hidden_size: 1024
126
+ codebook_dim: 8
127
+ vocos_dim: 384
128
+ vocos_intermediate_dim: 2048
129
+ vocos_num_layers: 12
130
+ pretrained_path: ${ckpt_root_path}/repcodec_emilia_50k_8192_norm_8d/86k_steps/model.safetensors
131
+
132
+
133
+ trainer:
134
+ _target_: models.tts.bpe_text2semantic.t2s_trainer.T2STrainer
135
+ args: ${args}
136
+ cfg: ${cfg}
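The `${...}` interpolations and `_target_` keys in this file follow OmegaConf/Hydra conventions. The loading code is not part of this commit, so the sketch below is an assumption about how such a config is typically consumed; the file path and the recursive instantiation of every nested `_target_` are assumptions too.

# Hedged sketch: not shown in this commit; assumes OmegaConf + Hydra are used.
from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.load("models/tts/valle_v2.1/cfg/base.yaml")
print(cfg.cfg.train.scheduler.warmup_steps)   # 8000, after ${...} resolution
trainer = instantiate(cfg.trainer)            # would build T2STrainer(args=..., cfg=...)
trainer.train_loop()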
models/tts/valle_v2.1/emilia_dataset_whole.py ADDED
@@ -0,0 +1,393 @@
1
+ import oss2 #pip install oss2
2
+ import io
3
+ import librosa
4
+ import torch
5
+ import json
6
+ import tqdm
7
+ import numpy as np
8
+ import logging
9
+ import pickle
10
+ import os
11
+ import time
12
+ from torch.utils.data import Dataset
13
+ from utils.g2p.g2p import phonemizer_g2p
14
+ from multiprocessing import Pool
15
+ import concurrent.futures
16
+ from pathlib import Path
17
+
18
+ from models.base.new_trainer import MainProcessLogger
19
+
20
+ # class PhonemizerWarningFilter(logging.Filter):
21
+ # def filter(self, record):
22
+ # # 只过滤 phonemizer 中的 WARNING 级别日志
23
+ # if record.name == 'phonemizer' and record.levelno == logging.WARNING:
24
+ # return False
25
+ # return False
26
+
27
+ # logger = logging.getLogger('phonemizer')
28
+ # filter = PhonemizerWarningFilter()
29
+ # logger.addFilter(filter)
30
+ # logging.basicConfig(level=logging.INFO)
31
+ # logger = logging.getLogger(__name__)
32
+ logger = MainProcessLogger(is_main_process=False)
33
+
34
+ os.environ['PHONEMIZER_ESPEAK_LIBRARY'] = '/usr/lib/x86_64-linux-gnu/libespeak-ng.so.1'
35
+ os.environ['PHONEMIZER_ESPEAK_PATH'] = '/usr/bin/espeak-ng'
36
+
37
+ LANG2CODE = {
38
+ 'zh': 349,
39
+ 'en': 350,
40
+ 'ja': 351,
41
+ 'ko': 352,
42
+ 'fr': 353,
43
+ 'de': 354,
44
+ }
45
+
46
+ AK = "LTAI5tJU3mNZASp8kUwWFjcq"
47
+ SK = "Ukhy7qWtMgwYVIMJSK3LTBpi1MLYrd"
48
+ bucket_name = "pjlab-3090-openmmlabpartner"
49
+ MOUNT_PATH = "/mnt/data/oss_beijing/"
50
+ data_json_path = '/mnt/petrelfs/hehaorui/jiaqi/Emilia-44.7k.json.gz'
51
+ num_token_per_second = 75
52
+ default_sr = 24000 # it may need to change sampling rate
53
+ duration_setting = {'min': 4, 'max': 20}
54
+
55
+ class EmiliaDataset(Dataset):
56
+ def __init__(self,
57
+ access_key_id=AK,
58
+ access_key_secret=SK,
59
+ bucket_name=bucket_name,
60
+ cache_type='path',
61
+ **kwargs): # 'path' or 'meta'
62
+ self.cache_type = cache_type
63
+
64
+ # Initialize OSS client
65
+ self.init_client(access_key_id, access_key_secret, bucket_name)
66
+ self.json_paths = []
67
+ self.wav_paths = []
68
+ self.language_list = ['en'] # Data language list
69
+ self.wav_path_index2duration = []
70
+ self.wav_path_index2phonelen = []
71
+ self.index2num_frames = []
72
+
73
+ self.json_path2meta = {}
74
+ self.json2filtered_idx = {}
75
+
76
+ self.cache_folder = '/mnt/petrelfs/hehaorui/jiaqi/tmp/emilia-cache-en'
77
+ Path(self.cache_folder).mkdir(parents=True, exist_ok=True)
78
+
79
+ self.wav_paths_cache = os.path.join(self.cache_folder, "wav_paths_cache.pkl")
80
+ self.json_paths_cache = os.path.join(self.cache_folder, "json_paths_cache.pkl")
81
+ self.duration_cache = os.path.join(self.cache_folder, "duration_cache.pkl")
82
+ self.phone_count_cache = os.path.join(self.cache_folder, "phone_count_cache.pkl")
83
+ self.json_path2meta_cache = os.path.join(self.cache_folder, "json_path2meta.pkl")
84
+
85
+ if cache_type == 'path':
86
+ if os.path.exists(self.wav_paths_cache) and os.path.exists(self.json_paths_cache) and os.path.exists(self.duration_cache) and os.path.exists(self.phone_count_cache):
87
+ self.load_cached_paths()
88
+ else:
89
+ logger.info("No cache exists")
90
+ self.get_all_paths_from_json(data_json_path)
91
+ self.save_cached_paths()
92
+ elif cache_type == 'meta':
93
+ if os.path.exists(self.wav_paths_cache) and os.path.exists(self.json_paths_cache):
94
+ self.load_cached_paths()
95
+ else:
96
+ logger.info("No cache exists")
97
+ self.get_all_paths_from_json(data_json_path)
98
+ self.save_cached_paths()
99
+ else:
100
+ logger.info("Incorrect cache loading way")
101
+ exit()
102
+
103
+ if cache_type == 'meta':
104
+ if os.path.exists(self.json_path2meta_cache):
105
+ self.load_path2meta()
106
+ else:
107
+ self.get_jsoncache_multiprocess(pool_size=8)
108
+
109
+ self.num_frame_indices = np.array(sorted(range(len(self.index2num_frames)), key=lambda k: self.index2num_frames[k]))
110
+
111
+
112
+ def init_client(self, access_key_id, access_key_secret, bucket_name):
113
+
114
+ logger.info("Start to initialize OSS client")
115
+ self.auth = oss2.Auth(access_key_id, access_key_secret)
116
+ self.bucket = oss2.Bucket(self.auth, "https://oss-cn-beijing.aliyuncs.com", bucket_name)
117
+ logger.info("OSS client initialized successfully")
118
+
119
+ def load_cached_paths(self):
120
+ logger.info("Loaded paths from cache files")
121
+ with open(self.wav_paths_cache, "rb") as f:
122
+ self.wav_paths = pickle.load(f)
123
+ with open(self.json_paths_cache, "rb") as f:
124
+ self.json_paths = pickle.load(f)
125
+ if self.cache_type == 'path':
126
+ with open(self.duration_cache, "rb") as f:
127
+ self.wav_path_index2duration = pickle.load(f)
128
+ with open(self.phone_count_cache, "rb") as f:
129
+ self.wav_path_index2phonelen = pickle.load(f)
130
+ for duration, phone_count in zip(self.wav_path_index2duration, self.wav_path_index2phonelen):
131
+ self.index2num_frames.append(duration * num_token_per_second + phone_count)
132
+ logger.info("All paths got successfully")
133
+ logger.info("Number of wavs: %d, Number of jsons: %d"
134
+ % (len(self.wav_paths), len(self.json_paths)))
135
+
136
+ def save_cached_paths(self):
137
+ with open(self.wav_paths_cache, "wb") as f:
138
+ pickle.dump(self.wav_paths, f)
139
+ with open(self.json_paths_cache, "wb") as f:
140
+ pickle.dump(self.json_paths, f)
141
+ if self.cache_type == 'path':
142
+ with open(self.duration_cache, "wb") as f:
143
+ pickle.dump(self.wav_path_index2duration, f)
144
+ with open(self.phone_count_cache, "wb") as f:
145
+ pickle.dump(self.wav_path_index2phonelen, f)
146
+ logger.info("Saved paths to cache files")
147
+
148
+ # Load JSON data from a compressed GZIP file
149
+ def load_compressed_json(self, filename):
150
+ import gzip
151
+ with gzip.open(filename, "rt", encoding="utf-8") as f:
152
+ return json.load(f)
153
+
154
+ def get_path_from_json(self, data):
155
+ if data['language'][0] not in self.language_list:
156
+ return
157
+ self.json_paths.append(data['json_path'])
158
+ is_exists = True
159
+ try:
160
+ if not self.bucket.object_exists(data['wav_path'][0]):
161
+ is_exists = False
162
+ except oss2.api.Exception as e:
163
+ is_exists = False
164
+ remove_idx = []
165
+ for wav, duration, phone_count in zip(data['wav_path'], data['duration'], data['phone_count']):
166
+ if duration < duration_setting['min'] or duration > duration_setting['max']:
167
+ idx = wav.split("_")[-1].split(".")[0]
168
+ remove_idx.append(idx)
169
+ continue
170
+ if is_exists:
171
+ self.wav_paths.append(wav)
172
+ else:
173
+ if '.mp3' in wav:
174
+ wav = wav.replace('.mp3', '.wav')
175
+ self.wav_paths.append(wav)
176
+ else:
177
+ wav = wav.replace('.wav', '.mp3')
178
+ self.wav_paths.append(wav)
179
+ self.wav_path_index2duration.append(duration)
180
+ self.wav_path_index2phonelen.append(phone_count)
181
+ self.index2num_frames.append(duration * num_token_per_second + phone_count)
182
+
183
+ self.json2filtered_idx[data['json_path']] = [int(i) for i in data['filtered_idx'].split(',') if i not in remove_idx]
184
+ if not self.json2filtered_idx[data['json_path']]:
185
+ self.json_paths.pop()
186
+
187
+ def get_all_paths_from_json(self, json_path):
188
+
189
+ data_list = self.load_compressed_json(json_path)
190
+ with concurrent.futures.ThreadPoolExecutor() as executor:
191
+ futures = [executor.submit(self.get_path_from_json, data) for data in tqdm.tqdm(data_list)]
192
+ data = [future.result() for future in tqdm.tqdm(futures)]
193
+
194
+ # Only 'meta' cache type use
195
+ def get_phone_count_and_duration(self, meta, idx_list):
196
+ new_meta = {}
197
+ if meta[0]['language'] not in self.language_list:
198
+ new_meta['0'] = meta[0]
199
+ return new_meta
200
+ text_list = []
201
+ for i in idx_list:
202
+ text_list.append(meta[i]['text'])
203
+ token_id = self.g2p(text_list, meta[0]['language'])[1]
204
+ for i, token in zip(idx_list, token_id):
205
+ nm = {}
206
+ nm['language'] = meta[i]['language']
207
+ nm['phone_id'] = token
208
+ nm['phone_count'] = len(token)
209
+ nm['duration'] = meta[i]['end'] - meta[i]['start']
210
+ new_meta[str(i)] = nm
211
+ del meta
212
+ return new_meta
213
+
214
+ # Only 'meta' cache type use
215
+ def process_json_cache(self, json_path):
216
+ default_meta = [{'text': '-1', 'language': 'others'}]
217
+ try:
218
+ file_bytes = self.bucket.get_object(json_path)
219
+ buffer = io.BytesIO(file_bytes.read())
220
+ json_cache = json.load(buffer)
221
+ del buffer, file_bytes
222
+ if json_cache is None:
223
+ logger.info("json is none")
224
+ elif isinstance(json_cache, (dict, list)) and not json_cache:
225
+ logger.info("json is none")
226
+ else:
227
+ return json_cache
228
+ except oss2.exceptions.NoSuchKey as e:
229
+ logger.info(
230
+ "Not found: http_status={0}, request_id={1}".format(e.status, e.request_id))
231
+ except Exception as e:
232
+ logger.info("Error json: {} error: {}".format(json_path, e))
233
+ return default_meta
234
+
235
+ # Only 'meta' cache type use
236
+ def get_jsoncache_multiprocess(self, pool_size):
237
+ logger.info("Start to build json pool")
238
+ logger.info("Start to get json cache")
239
+ json2meta = []
240
+ json_data = []
241
+ tmp_json_cache = os.path.join(self.cache_folder, 'json_cache.pkl')
242
+ if os.path.exists(tmp_json_cache):
243
+ with open(tmp_json_cache, 'rb') as f:
244
+ json_data = pickle.load(f)
245
+ logging.info("Load json_cache.pkl")
246
+ else:
247
+ with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as executor:
248
+ futures = [executor.submit(self.process_json_cache, path) for path in self.json_paths]
249
+ json_data = [future.result() for future in tqdm.tqdm(futures)]
250
+ with open(tmp_json_cache, 'wb') as f:
251
+ pickle.dump(json_data, f)
252
+ logging.info("Save json_cache.pkl")
253
+ logging.info("Get meta from cache")
254
+ for json, path in tqdm.tqdm(zip(json_data, self.json_paths), total=len(json_data)):
255
+ # print(json)
256
+ json2meta.append(self.get_phone_count_and_duration(json, self.json2filtered_idx[path]))
257
+ error_json_path_list = []
258
+ for i in range(len(json2meta)):
259
+ if not json2meta[i]:
260
+ error_json_path_list.append(self.json_paths[i])
261
+ elif json2meta[i][next(iter(json2meta[i]))]['language'] not in self.language_list:
262
+ language = json2meta[i][next(iter(json2meta[i]))]['language']
263
+ logger.info("{} is not in language list".format(language))
264
+ error_json_path_list.append(self.json_paths[i])
265
+ else:
266
+ self.json_path2meta[self.json_paths[i]] = json2meta[i]
267
+ logger.info("Remove error json path {}".format(error_json_path_list))
268
+ error_wav_path_list = []
269
+ for error in tqdm.tqdm(error_json_path_list):
270
+ self.json_paths.remove(error)
271
+ error = error.split('.json')[0]
272
+ for wav in self.wav_paths:
273
+ if error in wav:
274
+ error_wav_path_list.append(wav)
275
+ logger.info("Remove error wav path {}".format(error_wav_path_list))
276
+ for error in tqdm.tqdm(error_wav_path_list):
277
+ self.wav_paths.remove(error)
278
+ logger.info("Update cache")
279
+ with open(self.wav_paths_cache, "wb") as f:
280
+ pickle.dump(self.wav_paths, f)
281
+ with open(self.json_paths_cache, "wb") as f:
282
+ pickle.dump(self.json_paths, f)
283
+ with open(self.json_path2meta_cache, "wb") as f:
284
+ pickle.dump(self.json_path2meta, f)
285
+ logger.info("Json cache write to json_path2meta.pkl successfully")
286
+ del json2meta, error_wav_path_list, error_json_path_list
287
+
288
+ # Only 'meta' cache type use
289
+ def load_path2meta(self):
290
+ logger.info("Loaded meta from cache files")
291
+ self.json_path2meta = pickle.load(open(self.json_path2meta_cache, "rb"))
292
+ for path in self.wav_paths:
293
+ meta = self.get_meta_from_wav_path(path)
294
+ duration = meta['duration']
295
+ phone_count = meta['phone_count']
296
+ self.wav_path_index2duration.append(duration)
297
+ self.wav_path_index2phonelen.append(phone_count)
298
+ self.index2num_frames.append(duration * num_token_per_second + phone_count)
299
+
300
+ def get_meta_from_wav_path(self, wav_path):
301
+ index = int(wav_path.split("_")[-1].split(".")[0])
302
+ audio_name = "_".join(wav_path.split("/")[-1].split("_")[:-1])
303
+ dir_name = "/".join(wav_path.split("/")[:-1])
304
+ json_name = audio_name + ".json"
305
+ json_path = dir_name + "/" + json_name
306
+ meta = None
307
+ if self.cache_type == 'meta':
308
+ meta = self.json_path2meta[json_path][str(index)]
309
+ return meta
310
+ elif self.cache_type == 'path':
311
+ try:
312
+ file_bytes = self.bucket.get_object(json_path)
313
+ buffer = io.BytesIO(file_bytes.read())
314
+ meta = json.load(buffer)[index]
315
+ except oss2.exceptions.NoSuchKey as e:
316
+ logger.info(
317
+ "Not found: http_status={0}, request_id={1}".format(e.status, e.request_id))
318
+ except Exception as e:
319
+ logger.info("Error json: {} error: {}".format(json_path, e))
320
+ del index, audio_name, dir_name, json_name, json_path
321
+ return meta
322
+
323
+ def g2p(self, text, language):
324
+ return phonemizer_g2p(text, language)
325
+
326
+ def get_num_frames(self, index):
327
+ return self.wav_path_index2duration[index] * num_token_per_second + self.wav_path_index2phonelen[index]
328
+
329
+ def __len__(self):
330
+ return self.wav_paths.__len__()
331
+
332
+ def __getitem__(self, idx):
333
+
334
+ wav_path = self.wav_paths[idx]
335
+ file_bytes = None
336
+ position = np.where(self.num_frame_indices == idx)[0][0]
337
+ try:
338
+ random_index = np.random.choice(self.num_frame_indices[:position])
339
+ except Exception:  # e.g. empty slice when idx is already the shortest item
340
+ random_index = np.random.choice(self.num_frame_indices)
341
+ del position
342
+ try:
343
+ for i in range(2):
344
+ try:
345
+ file_bytes = self.bucket.get_object(wav_path.replace("_new", ""))
346
+ break
347
+ except Exception as e:
348
+ logger.info(f"[Filter meta func] Error is {e}")
349
+ time.sleep(i)
350
+ logger.info("retry")
351
+ except Exception:
+ logger.info("Getting data from OSS failed; sampling another item.")
353
+ return self.__getitem__(random_index)
354
+
355
+ meta = self.get_meta_from_wav_path(wav_path)
356
+ if file_bytes is not None and meta is not None:
357
+ try:
358
+ buffer = io.BytesIO(file_bytes.read())
359
+ speech, sr = librosa.load(buffer, sr=default_sr)
360
+ except Exception:
361
+ return self.__getitem__(random_index)
362
+ assert sr == 24000
363
+
364
+ shape = speech.shape
365
+ pad_shape = ((shape[0] // 320) + 1) * 320 - shape[0]
366
+ speech = np.pad(speech, (0, pad_shape), mode='constant')
367
+ del buffer, pad_shape, shape
368
+ if speech.shape[0] < default_sr * duration_setting['min'] or speech.shape[0] > default_sr * duration_setting['max']:
+ logger.info("Wav length is outside the allowed duration range")
370
+ return self.__getitem__(random_index)
371
+ else:
372
+ speech_tensor = torch.tensor(speech, dtype=torch.float32)
373
+
374
+ phone_id = self.g2p(meta['text'], meta['language'])[1] if self.cache_type == 'path' else meta['phone_id']
375
+ phone_id = torch.tensor(phone_id, dtype=torch.long)
376
+ phone_id = torch.cat([torch.tensor(LANG2CODE[meta['language']], dtype=torch.long).reshape(1), phone_id]) # add language token
377
+ return dict(
378
+ speech=speech_tensor,
379
+ phone=phone_id,
380
+ text=meta['text'],
381
+ language=meta['language'],
382
+ )
383
+ else:
384
+ logger.info("Failed to get file after retries.")
385
+ return self.__getitem__(random_index)
386
+
387
+ if __name__ == '__main__':
388
+
389
+ dataset = EmiliaDataset(AK, SK, bucket_name)
390
+ # print(dataset.__getitem__(0))
391
+ for batch in dataset:
392
+ breakpoint()
393
+ print()
models/tts/valle_v2.1/g2p_processor.py ADDED
@@ -0,0 +1,363 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import numpy as np
8
+ import os
9
+ import torch
10
+ import copy
11
+ from g2p_en import G2p
12
+ import re
13
+ import unicodedata
14
+ from g2p_en import G2p
15
+ from g2p_en.expand import normalize_numbers
16
+
17
+ g2p = G2p()
18
+
19
+ PHONE_SET = [
20
+ "!",
21
+ ",",
22
+ ".",
23
+ ".B",
24
+ ":",
25
+ "<BOS>",
26
+ "<EOS>",
27
+ "<PAD>",
28
+ "<UNK>",
29
+ "?",
30
+ "AA0B",
31
+ "AA0E",
32
+ "AA0I",
33
+ "AA1B",
34
+ "AA1E",
35
+ "AA1I",
36
+ "AA2B",
37
+ "AA2E",
38
+ "AA2I",
39
+ "AE0B",
40
+ "AE0E",
41
+ "AE0I",
42
+ "AE1B",
43
+ "AE1E",
44
+ "AE1I",
45
+ "AE2B",
46
+ "AE2E",
47
+ "AE2I",
48
+ "AH0B",
49
+ "AH0E",
50
+ "AH0I",
51
+ "AH1B",
52
+ "AH1E",
53
+ "AH1I",
54
+ "AH2B",
55
+ "AH2E",
56
+ "AH2I",
57
+ "AO0B",
58
+ "AO0E",
59
+ "AO0I",
60
+ "AO1",
61
+ "AO1B",
62
+ "AO1E",
63
+ "AO1I",
64
+ "AO2B",
65
+ "AO2E",
66
+ "AO2I",
67
+ "AW0B",
68
+ "AW0E",
69
+ "AW0I",
70
+ "AW1B",
71
+ "AW1E",
72
+ "AW1I",
73
+ "AW2B",
74
+ "AW2E",
75
+ "AW2I",
76
+ "AY0B",
77
+ "AY0E",
78
+ "AY0I",
79
+ "AY1B",
80
+ "AY1E",
81
+ "AY1I",
82
+ "AY2B",
83
+ "AY2E",
84
+ "AY2I",
85
+ "BB",
86
+ "BE",
87
+ "BI",
88
+ "CHB",
89
+ "CHE",
90
+ "CHI",
91
+ "DB",
92
+ "DE",
93
+ "DHB",
94
+ "DHE",
95
+ "DHI",
96
+ "DI",
97
+ "EH0B",
98
+ "EH0E",
99
+ "EH0I",
100
+ "EH1B",
101
+ "EH1E",
102
+ "EH1I",
103
+ "EH2B",
104
+ "EH2E",
105
+ "EH2I",
106
+ "ER0B",
107
+ "ER0E",
108
+ "ER0I",
109
+ "ER1B",
110
+ "ER1E",
111
+ "ER1I",
112
+ "ER2B",
113
+ "ER2E",
114
+ "ER2I",
115
+ "EY0B",
116
+ "EY0E",
117
+ "EY0I",
118
+ "EY1B",
119
+ "EY1E",
120
+ "EY1I",
121
+ "EY2B",
122
+ "EY2E",
123
+ "EY2I",
124
+ "FB",
125
+ "FE",
126
+ "FI",
127
+ "GB",
128
+ "GE",
129
+ "GI",
130
+ "HHB",
131
+ "HHE",
132
+ "HHI",
133
+ "IH0B",
134
+ "IH0E",
135
+ "IH0I",
136
+ "IH1B",
137
+ "IH1E",
138
+ "IH1I",
139
+ "IH2B",
140
+ "IH2E",
141
+ "IH2I",
142
+ "IY0B",
143
+ "IY0E",
144
+ "IY0I",
145
+ "IY1B",
146
+ "IY1E",
147
+ "IY1I",
148
+ "IY2B",
149
+ "IY2E",
150
+ "IY2I",
151
+ "JHB",
152
+ "JHE",
153
+ "JHI",
154
+ "KB",
155
+ "KE",
156
+ "KI",
157
+ "L",
158
+ "LB",
159
+ "LE",
160
+ "LI",
161
+ "MB",
162
+ "ME",
163
+ "MI",
164
+ "NB",
165
+ "NE",
166
+ "NGB",
167
+ "NGE",
168
+ "NGI",
169
+ "NI",
170
+ "OW0B",
171
+ "OW0E",
172
+ "OW0I",
173
+ "OW1B",
174
+ "OW1E",
175
+ "OW1I",
176
+ "OW2B",
177
+ "OW2E",
178
+ "OW2I",
179
+ "OY0B",
180
+ "OY0E",
181
+ "OY0I",
182
+ "OY1B",
183
+ "OY1E",
184
+ "OY1I",
185
+ "OY2B",
186
+ "OY2E",
187
+ "OY2I",
188
+ "PB",
189
+ "PE",
190
+ "PI",
191
+ "RB",
192
+ "RE",
193
+ "RI",
194
+ "SB",
195
+ "SE",
196
+ "SHB",
197
+ "SHE",
198
+ "SHI",
199
+ "SI",
200
+ "TB",
201
+ "TE",
202
+ "THB",
203
+ "THE",
204
+ "THI",
205
+ "TI",
206
+ "UH0B",
207
+ "UH0E",
208
+ "UH0I",
209
+ "UH1B",
210
+ "UH2B",
211
+ "UH1E",
212
+ "UH1I",
213
+ "UH2E",
214
+ "UH2I",
215
+ "UW0B",
216
+ "UW0E",
217
+ "UW0I",
218
+ "UW1B",
219
+ "UW1E",
220
+ "UW1I",
221
+ "UW2B",
222
+ "UW2E",
223
+ "UW2I",
224
+ "VB",
225
+ "VE",
226
+ "VI",
227
+ "WB",
228
+ "WE",
229
+ "WI",
230
+ "YB",
231
+ "YE",
232
+ "YI",
233
+ "ZB",
234
+ "ZE",
235
+ "ZHB",
236
+ "ZHE",
237
+ "ZHI",
238
+ "ZI",
239
+ "|",
240
+ ]
241
+ PHPONE2ID = {PHONE_SET[i]: i for i in range(len(PHONE_SET))}
242
+
243
+ PUNCS = "!,.?;:"
244
+
245
+
246
+ def is_sil_phoneme(p):
247
+ return p == "" or not p[0].isalpha()
248
+
249
+
250
+ def add_bdr(txt_struct):
251
+ txt_struct_ = []
252
+ for i, ts in enumerate(txt_struct):
253
+ txt_struct_.append(ts)
254
+ if (
255
+ i != len(txt_struct) - 1
256
+ and not is_sil_phoneme(txt_struct[i][0])
257
+ and not is_sil_phoneme(txt_struct[i + 1][0])
258
+ ):
259
+ txt_struct_.append(["|", ["|"]])
260
+ return txt_struct_
261
+
262
+
263
+ def preprocess_text(text):
264
+ text = normalize_numbers(text)
265
+ text = "".join(
266
+ char
267
+ for char in unicodedata.normalize("NFD", text)
268
+ if unicodedata.category(char) != "Mn"
269
+ ) # Strip accents
270
+ text = text.lower()
271
+ text = re.sub("['\"()]+", "", text)
272
+ text = re.sub("[-]+", " ", text)
273
+ text = re.sub(f"[^ a-z{PUNCS}]", "", text)
274
+ text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> !
275
+ text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> !
276
+ text = text.replace("i.e.", "that is")
277
+ text = text.replace("i.e.", "that is")
278
+ text = text.replace("etc.", "etc")
279
+ text = re.sub(f"([{PUNCS}])", r" ", text) # remove punctuations for now
280
+ text = re.sub(rf"\s+", r" ", text)
281
+ return text
282
+
283
+
284
+ def postprocess(txt_struct):
285
+ while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[0][0]):
286
+ txt_struct = txt_struct[1:]
287
+ while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[-1][0]):
288
+ txt_struct = txt_struct[:-1]
289
+ txt_struct = add_bdr(txt_struct)
290
+ txt_struct = [["<BOS>", ["<BOS>"]]] + txt_struct + [["<EOS>", ["<EOS>"]]]
291
+ return txt_struct
292
+
293
+
294
+ def process(txt, g2p):
295
+ txt = preprocess_text(txt).strip()
296
+ phs = g2p(txt)
297
+ txt_struct = [[w, []] for w in txt.split(" ")]
298
+ i_word = 0
299
+ for p in phs:
300
+ if p == " ":
301
+ i_word += 1
302
+ else:
303
+ txt_struct[i_word][1].append(p)
304
+
305
+ txt_struct_ret = copy.deepcopy(txt_struct)
306
+
307
+ for i_word in range(len(txt_struct)):
308
+ if not is_sil_phoneme(txt_struct[i_word][0]):
309
+ if len(txt_struct[i_word][1]) > 1:
310
+ txt_struct_ret[i_word][1][0] += "B"
311
+ for i in range(1, len(txt_struct[i_word][1]) - 1):
312
+ txt_struct_ret[i_word][1][i] += "I"
313
+ txt_struct_ret[i_word][1][-1] += "E"
314
+ else:
315
+ txt_struct_ret[i_word][1][0] += "B"
316
+
317
+ txt_struct_ret = postprocess(txt_struct_ret)
318
+
319
+ return txt_struct_ret, txt
320
+
321
+
322
+ def test():
323
+ g2p = G2p()
324
+ txt = "This is a test sentence."
325
+ txt_struct, txt = process(txt, g2p)
326
+ print(txt_struct)
327
+ print(txt)
328
+ phone_seq = [p for w in txt_struct for p in w[1]]
329
+ print(phone_seq)
330
+ phone_id = [PHPONE2ID[p] for p in phone_seq]
331
+ print(phone_id)
332
+
333
+
334
+ class G2pProcessor:
335
+ def __init__(self):
336
+ self.g2p = G2p()
337
+
338
+ def __call__(self, txt, lang="en"):
339
+ return self.txt2phoneid(txt)
340
+
341
+ def txt2phoneid(self, txt):
342
+ txt_struct, txt = process(txt, self.g2p)
343
+ phone_seq = [p for w in txt_struct for p in w[1]]
344
+ phone_id = [PHPONE2ID[p] for p in phone_seq]
345
+ return None, phone_id
346
+
347
+ def phoneid2txt(self, phone_id):
348
+ txt = []
349
+ for i in phone_id:
350
+ txt.append(PHONE_SET[i])
351
+ return txt
352
+
353
+
354
+ if __name__ == "__main__":
355
+ g2p = G2pProcessor()
356
+ txt = "This is a test sentence."
357
+ phoneid = g2p.txt2phoneid(txt)[1]
358
+ # output: [5, 73, 118, 175, 218, 116, 213, 218, 28, 218, 180, 82, 179, 181, 218, 174, 82, 149, 185, 30, 149, 175, 6]
359
+ # print(phoneid)
360
+ print(g2p.phoneid2txt(phoneid))
361
+ # output: ['<BOS>', 'DHB', 'IH1I', 'SE', '|', 'IH1B', 'ZE', '|', 'AH0B', '|', 'TB', 'EH1I', 'SI', 'TE', '|', 'SB', 'EH1I', 'NI', 'TI', 'AH0I', 'NI', 'SE', '<EOS>']
362
+ print(len(PHONE_SET))
363
+ # output: 219
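
Note on the phone inventory above: apart from the raw ARPAbet symbols, each entry carries a word-position suffix that `process` attaches per word: the first phone of a word gets "B", interior phones get "I", and the last phone gets "E" (single-phone words get "B" only). A minimal standalone sketch of that rule follows; the helper name `tag_word_position` is illustrative and not part of the file.

```python
# Sketch of the word-position tagging behind entries like "TB", "EH1I", "SI", "TE".
def tag_word_position(phones):
    """Append 'B' to the first phone, 'I' to interior phones, 'E' to the last phone."""
    if len(phones) == 1:
        return [phones[0] + "B"]  # single-phone words are tagged 'B' only
    return (
        [phones[0] + "B"]
        + [p + "I" for p in phones[1:-1]]
        + [phones[-1] + "E"]
    )

print(tag_word_position(["T", "EH1", "S", "T"]))  # ['TB', 'EH1I', 'SI', 'TE']
print(tag_word_position(["AH0"]))                 # ['AH0B']
```

This matches the `__main__` output above, where "test" becomes `'TB', 'EH1I', 'SI', 'TE'` and `"|"` marks word boundaries.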
models/tts/valle_v2.1/libritts_dataset.py ADDED
@@ -0,0 +1,271 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from utils.data_utils import *
10
+ from tqdm import tqdm
11
+ from g2p_en import G2p
12
+ import librosa
13
+ from torch.utils.data import Dataset
14
+ import pandas as pd
15
+ import time
16
+ import io
17
+
18
+ SAMPLE_RATE = 16000
19
+ # g2p
20
+ from .g2p_processor import G2pProcessor
21
+
22
+ phonemizer_g2p = G2pProcessor()
23
+
24
+
25
+ class VALLEDataset(Dataset):
26
+ def __init__(self, args):
27
+ print(f"Initializing VALLEDataset")
28
+ self.dataset_list = args.dataset_list
29
+
30
+ print(f"using sampling rate {SAMPLE_RATE}")
31
+
32
+ # set dataframe column names
33
+ book_col_name = [
34
+ "ID",
35
+ "Original_text",
36
+ "Normalized_text",
37
+ "Aligned_or_not",
38
+ "Start_time",
39
+ "End_time",
40
+ "Signal_to_noise_ratio",
41
+ ]
42
+ trans_col_name = [
43
+ "ID",
44
+ "Original_text",
45
+ "Normalized_text",
46
+ "Dir_path",
47
+ "Duration",
48
+ ]
49
+ self.metadata_cache = pd.DataFrame(columns=book_col_name)
50
+ self.trans_cache = pd.DataFrame(columns=trans_col_name)
51
+ # dataset_cache_dir = args.cache_dir # cache_dir
52
+ # print(f"args.cache_dir = ", args.cache_dir)
53
+ # os.makedirs(dataset_cache_dir, exist_ok=True)
54
+
55
+ ######## add data dir to dataset2dir ##########
56
+ self.dataset2dir = {
57
+ "dev-clean": f"{args.data_dir}/dev-clean",
58
+ "dev-other": f"{args.data_dir}/dev-other",
59
+ "test-clean": f"{args.data_dir}/test-clean",
60
+ "test-other": f"{args.data_dir}/test-other",
61
+ "train-clean-100": f"{args.data_dir}/train-clean-100",
62
+ "train-clean-360": f"{args.data_dir}/train-clean-360",
63
+ "train-other-500": f"{args.data_dir}/train-other-500",
64
+ }
65
+
66
+ ###### load metadata and transcripts #####
67
+ for dataset_name in self.dataset_list:
68
+ print("Initializing dataset: ", dataset_name)
69
+ # get [book,transcripts,audio] files list
70
+ self.book_files_list = self.get_metadata_files(
71
+ self.dataset2dir[dataset_name]
72
+ )
73
+ self.trans_files_list = self.get_trans_files(self.dataset2dir[dataset_name])
74
+
75
+ ## create metadata_cache (the book.tsv files are unfiltered and some referenced files do not exist, but they contain Duration and Signal_to_noise_ratio)
76
+ print("reading paths for dataset...")
77
+ for book_path in tqdm(self.book_files_list):
78
+ tmp_cache = pd.read_csv(
79
+ book_path, sep="\t", names=book_col_name, quoting=3
80
+ )
81
+ self.metadata_cache = pd.concat(
82
+ [self.metadata_cache, tmp_cache], ignore_index=True
83
+ )
84
+ self.metadata_cache.set_index("ID", inplace=True)
85
+
86
+ ## create transcripts (the trans.tsv file)
87
+ print("creating transcripts for dataset...")
88
+ for trans_path in tqdm(self.trans_files_list):
89
+ tmp_cache = pd.read_csv(
90
+ trans_path, sep="\t", names=trans_col_name, quoting=3
91
+ )
92
+ tmp_cache["Dir_path"] = os.path.dirname(trans_path)
93
+ self.trans_cache = pd.concat(
94
+ [self.trans_cache, tmp_cache], ignore_index=True
95
+ )
96
+ self.trans_cache.set_index("ID", inplace=True)
97
+
98
+ ## calc duration
99
+ self.trans_cache["Duration"] = (
100
+ self.metadata_cache.End_time[self.trans_cache.index]
101
+ - self.metadata_cache.Start_time[self.trans_cache.index]
102
+ )
103
+ ## add fullpath
104
+ # self.trans_cache['Full_path'] = os.path.join(self.dataset2dir[dataset_name],self.trans_cache['ID'])
105
+
106
+ # filter_by_duration: filter_out files with duration < 3.0 or > 15.0
107
+ print(f"Filtering files with duration between 3.0 and 15.0 seconds")
108
+ print(f"Before filtering: {len(self.trans_cache)}")
109
+ self.trans_cache = self.trans_cache[
110
+ (self.trans_cache["Duration"] >= 3.0)
111
+ & (self.trans_cache["Duration"] <= 15.0)
112
+ ]
113
+ print(f"After filtering: {len(self.trans_cache)}")
114
+
115
+ def get_metadata_files(self, directory):
116
+ book_files = []
117
+ for root, _, files in os.walk(directory):
118
+ for file in files:
119
+ if file.endswith(".book.tsv") and file[0] != ".":
120
+ rel_path = os.path.join(root, file)
121
+ book_files.append(rel_path)
122
+ return book_files
123
+
124
+ def get_trans_files(self, directory):
125
+ trans_files = []
126
+ for root, _, files in os.walk(directory):
127
+ for file in files:
128
+ if file.endswith(".trans.tsv") and file[0] != ".":
129
+ rel_path = os.path.join(root, file)
130
+ trans_files.append(rel_path)
131
+ return trans_files
132
+
133
+ def get_audio_files(self, directory):
134
+ audio_files = []
135
+ for root, _, files in os.walk(directory):
136
+ for file in files:
137
+ if file.endswith((".flac", ".wav", ".opus")):
138
+ rel_path = os.path.relpath(os.path.join(root, file), directory)
139
+ audio_files.append(rel_path)
140
+ return audio_files
141
+
142
+ def get_num_frames(self, index):
143
+ # get_num_frames(durations) by index
144
+ duration = self.trans_cache["Duration"][index]  # Duration is stored in trans_cache
145
+ # num_frames = duration * SAMPLE_RATE
146
+ num_frames = int(duration * 75)
147
+
148
+ # file_rel_path = self.meta_data_cache['relpath'][index]
149
+ # uid = file_rel_path.rstrip('.flac').split('/')[-1]
150
+ # num_frames += len(self.transcripts[uid])
151
+ return num_frames
152
+
153
+ def __len__(self):
154
+ return len(self.trans_cache)
155
+
156
+ def __getitem__(self, idx):
157
+ # Get the file rel path
158
+ file_dir_path = self.trans_cache["Dir_path"].iloc[idx]
159
+ # Get uid
160
+ uid = self.trans_cache.index[idx]
161
+ # Get the file name from cache uid
162
+ file_name = uid + ".wav"
163
+ # Get the full file path
164
+ full_file_path = os.path.join(file_dir_path, file_name)
165
+
166
+ # get phone
167
+ phone = self.trans_cache["Normalized_text"][uid]
168
+ phone = phonemizer_g2p(phone, "en")[1]
169
+ # load speech
170
+ speech, _ = librosa.load(full_file_path, sr=SAMPLE_RATE)
171
+ # if self.resample_to_24k:
172
+ # speech = librosa.resample(speech, orig_sr=SAMPLE_RATE, target_sr=24000)
173
+ # speech = torch.tensor(speech, dtype=torch.float32)
174
+ # pad speech to multiples of 200
175
+
176
+ # remainder = speech.size(0) % 200
177
+ # if remainder > 0:
178
+ # pad = 200 - remainder
179
+ # speech = torch.cat([speech, torch.zeros(pad, dtype=torch.float32)], dim=0)
180
+
181
+ # inputs = self._get_reference_vc(speech, hop_length=200)
182
+ inputs = {}
183
+ # Get the speaker id
184
+ # speaker = self.meta_data_cache['speaker'][idx]
185
+ # speaker_id = self.speaker2id[speaker]
186
+ # inputs["speaker_id"] = speaker_id
187
+ inputs["speech"] = speech  # 16kHz speech (loaded at SAMPLE_RATE), [T]
188
+ inputs["phone"] = phone # [T]
189
+ return inputs
190
+
191
+
192
+ def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
193
+ if len(batch) == 0:
194
+ return 0
195
+ if len(batch) == max_sentences:
196
+ return 1
197
+ if num_tokens > max_tokens:
198
+ return 1
199
+ return 0
200
+
201
+
202
+ def batch_by_size(
203
+ indices,
204
+ num_tokens_fn,
205
+ max_tokens=None,
206
+ max_sentences=None,
207
+ required_batch_size_multiple=1,
208
+ ):
209
+ """
210
+ Yield mini-batches of indices bucketed by size. Batches may contain
211
+ sequences of different lengths.
212
+
213
+ Args:
214
+ indices (List[int]): ordered list of dataset indices
215
+ num_tokens_fn (callable): function that returns the number of tokens at
216
+ a given index
217
+ max_tokens (int, optional): max number of tokens in each batch
218
+ (default: None).
219
+ max_sentences (int, optional): max number of sentences in each
220
+ batch (default: None).
221
+ required_batch_size_multiple (int, optional): require batch size to
222
+ be a multiple of N (default: 1).
223
+ """
224
+ bsz_mult = required_batch_size_multiple
225
+
226
+ sample_len = 0
227
+ sample_lens = []
228
+ batch = []
229
+ batches = []
230
+ for i in range(len(indices)):
231
+ idx = indices[i]
232
+ num_tokens = num_tokens_fn(idx)
233
+ sample_lens.append(num_tokens)
234
+ sample_len = max(sample_len, num_tokens)
235
+
236
+ assert (
237
+ sample_len <= max_tokens
238
+ ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format(
239
+ idx, sample_len, max_tokens
240
+ )
241
+ num_tokens = (len(batch) + 1) * sample_len
242
+
243
+ if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
244
+ mod_len = max(
245
+ bsz_mult * (len(batch) // bsz_mult),
246
+ len(batch) % bsz_mult,
247
+ )
248
+ batches.append(batch[:mod_len])
249
+ batch = batch[mod_len:]
250
+ sample_lens = sample_lens[mod_len:]
251
+ sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
252
+ batch.append(idx)
253
+ if len(batch) > 0:
254
+ batches.append(batch)
255
+ return batches
256
+
257
+
258
+ def test():
259
+ from utils.util import load_config
260
+
261
+ cfg = load_config("./egs/tts/VALLE_V2/exp_ar_libritts.json")
262
+ dataset = VALLEDataset(cfg.dataset)
263
+ metadata_cache = dataset.metadata_cache
264
+ trans_cache = dataset.trans_cache
265
+ print(trans_cache.head(10))
266
+ # print(dataset.book_files_list)
267
+ breakpoint()
268
+
269
+
270
+ if __name__ == "__main__":
271
+ test()
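
A rough usage sketch of the dynamic batching above: `batch_by_size` grows each batch until adding the next item would push `(len(batch) + 1) * longest_item_in_batch` past `max_tokens` (or past `max_sentences`), then closes the batch on a `required_batch_size_multiple` boundary. The snippet assumes `batch_by_size` and `_is_batch_full` from this file are available in the current session (e.g. copied into a REPL); the toy lengths are made up.

```python
lengths = [100, 120, 80, 300, 60, 90]  # pretend per-utterance frame counts
batches = batch_by_size(
    indices=list(range(len(lengths))),
    num_tokens_fn=lambda i: lengths[i],
    max_tokens=400,      # cap on (items in batch) * (longest item in the batch)
    max_sentences=4,     # optional hard cap on items per batch
)
print(batches)           # [[0, 1, 2], [3], [4, 5]]
```

The 300-frame utterance both closes the first batch and ends up alone, which is how the sampler keeps the padded token count per batch roughly bounded.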
models/tts/valle_v2.1/modeling_llama.py ADDED
@@ -0,0 +1,1043 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # This code is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
6
+
7
+ # Original work copyright
8
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
11
+ # and OPT implementations in this library. It has been modified from its
12
+ # original forms to accommodate minor architectural differences compared
13
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
14
+ #
15
+ # Licensed under the Apache License, Version 2.0 (the "License");
16
+ # you may not use this file except in compliance with the License.
17
+ # You may obtain a copy of the License at
18
+ #
19
+ # http://www.apache.org/licenses/LICENSE-2.0
20
+ #
21
+ # Unless required by applicable law or agreed to in writing, software
22
+ # distributed under the License is distributed on an "AS IS" BASIS,
23
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24
+ # See the License for the specific language governing permissions and
25
+ # limitations under the License.
26
+ """ PyTorch LLaMA model."""
27
+ import math
28
+ from typing import List, Optional, Tuple, Union
29
+
30
+ import torch
31
+ import torch.utils.checkpoint
32
+ from torch import nn
33
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
34
+
35
+ from transformers.models.llama.modeling_llama import ACT2FN
36
+ from transformers.models.llama.modeling_llama import (
37
+ BaseModelOutputWithPast,
38
+ CausalLMOutputWithPast,
39
+ SequenceClassifierOutputWithPast,
40
+ )
41
+ from transformers.models.llama.modeling_llama import PreTrainedModel
42
+ from transformers.models.llama.modeling_llama import (
43
+ add_start_docstrings,
44
+ add_start_docstrings_to_model_forward,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from transformers.models.llama.modeling_llama import LlamaConfig
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ _CONFIG_FOR_DOC = "LlamaConfig"
54
+
55
+
56
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
57
+ def _make_causal_mask(
58
+ input_ids_shape: torch.Size,
59
+ dtype: torch.dtype,
60
+ device: torch.device,
61
+ past_key_values_length: int = 0,
62
+ ):
63
+ """
64
+ Make causal mask used for bi-directional self-attention.
65
+ """
66
+ bsz, tgt_len = input_ids_shape
67
+ mask = torch.full(
68
+ (tgt_len, tgt_len),
69
+ torch.tensor(torch.finfo(dtype).min, device=device),
70
+ device=device,
71
+ )
72
+ mask_cond = torch.arange(mask.size(-1), device=device)
73
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
74
+ mask = mask.to(dtype)
75
+
76
+ if past_key_values_length > 0:
77
+ mask = torch.cat(
78
+ [
79
+ torch.zeros(
80
+ tgt_len, past_key_values_length, dtype=dtype, device=device
81
+ ),
82
+ mask,
83
+ ],
84
+ dim=-1,
85
+ )
86
+ return mask[None, None, :, :].expand(
87
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
88
+ )
89
+
90
+
91
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
92
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
93
+ """
94
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
95
+ """
96
+ bsz, src_len = mask.size()
97
+ tgt_len = tgt_len if tgt_len is not None else src_len
98
+
99
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
100
+
101
+ inverted_mask = 1.0 - expanded_mask
102
+
103
+ return inverted_mask.masked_fill(
104
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
105
+ )
106
+
107
+
108
+ class LlamaRMSNorm(nn.Module):
109
+ def __init__(self, hidden_size, eps=1e-6):
110
+ """
111
+ LlamaRMSNorm is equivalent to T5LayerNorm
112
+ """
113
+ super().__init__()
114
+ self.weight = nn.Parameter(torch.ones(hidden_size))
115
+ self.variance_epsilon = eps
116
+
117
+ def forward(self, hidden_states):
118
+ input_dtype = hidden_states.dtype
119
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
120
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
121
+
122
+ return (self.weight * hidden_states).to(input_dtype)
123
+
124
+
125
+ class LlamaRotaryEmbedding(torch.nn.Module):
126
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
127
+ super().__init__()
128
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
129
+ self.register_buffer("inv_freq", inv_freq)
130
+
131
+ # Build here to make `torch.jit.trace` work.
132
+ self.max_seq_len_cached = max_position_embeddings
133
+ t = torch.arange(
134
+ self.max_seq_len_cached,
135
+ device=self.inv_freq.device,
136
+ dtype=self.inv_freq.dtype,
137
+ )
138
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
139
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
140
+ emb = torch.cat((freqs, freqs), dim=-1)
141
+ self.register_buffer(
142
+ "cos_cached", emb.cos()[None, None, :, :], persistent=False
143
+ )
144
+ self.register_buffer(
145
+ "sin_cached", emb.sin()[None, None, :, :], persistent=False
146
+ )
147
+
148
+ def forward(self, x, seq_len=None):
149
+ # x: [bs, num_attention_heads, seq_len, head_size]
150
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
151
+ if seq_len > self.max_seq_len_cached:
152
+ self.max_seq_len_cached = seq_len
153
+ t = torch.arange(
154
+ self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
155
+ )
156
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
157
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
158
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
159
+ self.register_buffer(
160
+ "cos_cached", emb.cos()[None, None, :, :], persistent=False
161
+ )
162
+ self.register_buffer(
163
+ "sin_cached", emb.sin()[None, None, :, :], persistent=False
164
+ )
165
+ return (
166
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
167
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
168
+ )
169
+
170
+
171
+ def rotate_half(x):
172
+ """Rotates half the hidden dims of the input."""
173
+ x1 = x[..., : x.shape[-1] // 2]
174
+ x2 = x[..., x.shape[-1] // 2 :]
175
+ return torch.cat((-x2, x1), dim=-1)
176
+
177
+
178
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
179
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
180
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
181
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
182
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
183
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
184
+ q_embed = (q * cos) + (rotate_half(q) * sin)
185
+ k_embed = (k * cos) + (rotate_half(k) * sin)
186
+ return q_embed, k_embed
187
+
188
+
189
+ class LlamaMLP(nn.Module):
190
+ def __init__(
191
+ self,
192
+ hidden_size: int,
193
+ intermediate_size: int,
194
+ hidden_act: str,
195
+ ):
196
+ super().__init__()
197
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
198
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
199
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
200
+ self.act_fn = ACT2FN[hidden_act]
201
+
202
+ def forward(self, x):
203
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
204
+
205
+
206
+ class LlamaAttention(nn.Module):
207
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
208
+
209
+ def __init__(self, config: LlamaConfig, **kwargs):
210
+ super().__init__()
211
+ self.config = config
212
+ self.hidden_size = config.hidden_size
213
+ self.num_heads = config.num_attention_heads
214
+ self.head_dim = self.hidden_size // self.num_heads
215
+ self.max_position_embeddings = config.max_position_embeddings
216
+
217
+ if (self.head_dim * self.num_heads) != self.hidden_size:
218
+ raise ValueError(
219
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
220
+ f" and `num_heads`: {self.num_heads})."
221
+ )
222
+ self.q_proj = nn.Linear(
223
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
224
+ )
225
+ self.k_proj = nn.Linear(
226
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
227
+ )
228
+ self.v_proj = nn.Linear(
229
+ self.hidden_size, self.num_heads * self.head_dim, bias=False
230
+ )
231
+ self.o_proj = nn.Linear(
232
+ self.num_heads * self.head_dim, self.hidden_size, bias=False
233
+ )
234
+ self.rotary_emb = LlamaRotaryEmbedding(
235
+ self.head_dim, max_position_embeddings=self.max_position_embeddings
236
+ )
237
+
238
+ if "layer_idx" in kwargs:
239
+ self.layer_idx = kwargs["layer_idx"]
240
+
241
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
242
+ return (
243
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
244
+ .transpose(1, 2)
245
+ .contiguous()
246
+ )
247
+
248
+ def forward(
249
+ self,
250
+ hidden_states: torch.Tensor,
251
+ attention_mask: Optional[torch.Tensor] = None,
252
+ position_ids: Optional[torch.LongTensor] = None,
253
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
254
+ output_attentions: bool = False,
255
+ use_cache: bool = False,
256
+ **kwargs,
257
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
258
+ bsz, q_len, _ = hidden_states.size()
259
+
260
+ query_states = (
261
+ self.q_proj(hidden_states)
262
+ .view(bsz, q_len, self.num_heads, self.head_dim)
263
+ .transpose(1, 2)
264
+ )
265
+ key_states = (
266
+ self.k_proj(hidden_states)
267
+ .view(bsz, q_len, self.num_heads, self.head_dim)
268
+ .transpose(1, 2)
269
+ )
270
+ value_states = (
271
+ self.v_proj(hidden_states)
272
+ .view(bsz, q_len, self.num_heads, self.head_dim)
273
+ .transpose(1, 2)
274
+ )
275
+
276
+ kv_seq_len = key_states.shape[-2]
277
+ if past_key_value is not None:
278
+ kv_seq_len += past_key_value[0].shape[-2]
279
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
280
+ query_states, key_states = apply_rotary_pos_emb(
281
+ query_states, key_states, cos, sin, position_ids
282
+ )
283
+ # [bsz, nh, t, hd]
284
+
285
+ if past_key_value is not None:
286
+ # reuse k, v, self_attention
287
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
288
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
289
+
290
+ past_key_value = (key_states, value_states) if use_cache else None
291
+
292
+ attn_weights = torch.matmul(
293
+ query_states, key_states.transpose(2, 3)
294
+ ) / math.sqrt(self.head_dim)
295
+
296
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
297
+ raise ValueError(
298
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
299
+ f" {attn_weights.size()}"
300
+ )
301
+
302
+ if attention_mask is not None:
303
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
304
+ raise ValueError(
305
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
306
+ )
307
+ attn_weights = attn_weights + attention_mask
308
+ attn_weights = torch.max(
309
+ attn_weights,
310
+ torch.tensor(
311
+ torch.finfo(attn_weights.dtype).min, device=attn_weights.device
312
+ ),
313
+ )
314
+
315
+ unnormed_attn_weights = attn_weights
316
+
317
+ # upcast attention to fp32
318
+ attn_weights = nn.functional.softmax(
319
+ attn_weights, dim=-1, dtype=torch.float32
320
+ ).to(query_states.dtype)
321
+ attn_output = torch.matmul(attn_weights, value_states)
322
+
323
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
324
+ raise ValueError(
325
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
326
+ f" {attn_output.size()}"
327
+ )
328
+
329
+ attn_output = attn_output.transpose(1, 2)
330
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
331
+
332
+ attn_output = self.o_proj(attn_output)
333
+
334
+ if not output_attentions:
335
+ attn_weights = None
336
+
337
+ return attn_output, unnormed_attn_weights, past_key_value
338
+
339
+
340
+ class LlamaDecoderLayer(nn.Module):
341
+ def __init__(self, config: LlamaConfig, **kwargs):
342
+ super().__init__()
343
+ self.hidden_size = config.hidden_size
344
+ self.self_attn = LlamaAttention(config=config)
345
+ self.mlp = LlamaMLP(
346
+ hidden_size=self.hidden_size,
347
+ intermediate_size=config.intermediate_size,
348
+ hidden_act=config.hidden_act,
349
+ )
350
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
351
+ self.post_attention_layernorm = LlamaRMSNorm(
352
+ config.hidden_size, eps=config.rms_norm_eps
353
+ )
354
+
355
+ def forward(
356
+ self,
357
+ hidden_states: torch.Tensor,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ position_ids: Optional[torch.LongTensor] = None,
360
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
361
+ output_attentions: Optional[bool] = False,
362
+ use_cache: Optional[bool] = False,
363
+ ) -> Tuple[
364
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
365
+ ]:
366
+ """
367
+ Args:
368
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
369
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
370
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
371
+ output_attentions (`bool`, *optional*):
372
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
373
+ returned tensors for more detail.
374
+ use_cache (`bool`, *optional*):
375
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
376
+ (see `past_key_values`).
377
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
378
+ """
379
+
380
+ residual = hidden_states
381
+
382
+ hidden_states = self.input_layernorm(hidden_states)
383
+
384
+ # Self Attention
385
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
386
+ hidden_states=hidden_states,
387
+ attention_mask=attention_mask,
388
+ position_ids=position_ids,
389
+ past_key_value=past_key_value,
390
+ output_attentions=output_attentions,
391
+ use_cache=use_cache,
392
+ )
393
+ hidden_states = residual + hidden_states
394
+
395
+ # Fully Connected
396
+ residual = hidden_states
397
+ hidden_states = self.post_attention_layernorm(hidden_states)
398
+ hidden_states = self.mlp(hidden_states)
399
+ hidden_states = residual + hidden_states
400
+
401
+ outputs = (hidden_states,)
402
+
403
+ if output_attentions:
404
+ outputs += (self_attn_weights,)
405
+
406
+ if use_cache:
407
+ outputs += (present_key_value,)
408
+
409
+ return outputs
410
+
411
+
412
+ LLAMA_START_DOCSTRING = r"""
413
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
414
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
415
+ etc.)
416
+
417
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
418
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
419
+ and behavior.
420
+
421
+ Parameters:
422
+ config ([`LlamaConfig`]):
423
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
424
+ load the weights associated with the model, only the configuration. Check out the
425
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
426
+ """
427
+
428
+
429
+ @add_start_docstrings(
430
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
431
+ LLAMA_START_DOCSTRING,
432
+ )
433
+ class LlamaPreTrainedModel(PreTrainedModel):
434
+ config_class = LlamaConfig
435
+ base_model_prefix = "model"
436
+ supports_gradient_checkpointing = True
437
+ _no_split_modules = ["LlamaDecoderLayer"]
438
+ _skip_keys_device_placement = "past_key_values"
439
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
440
+
441
+ def _init_weights(self, module):
442
+ std = self.config.initializer_range
443
+ if isinstance(module, nn.Linear):
444
+ module.weight.data.normal_(mean=0.0, std=std)
445
+ if module.bias is not None:
446
+ module.bias.data.zero_()
447
+ elif isinstance(module, nn.Embedding):
448
+ module.weight.data.normal_(mean=0.0, std=std)
449
+ if module.padding_idx is not None:
450
+ module.weight.data[module.padding_idx].zero_()
451
+
452
+ def _set_gradient_checkpointing(self, module, value=False):
453
+ if isinstance(module, LlamaModel):
454
+ module.gradient_checkpointing = value
455
+
456
+
457
+ LLAMA_INPUTS_DOCSTRING = r"""
458
+ Args:
459
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
460
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
461
+ it.
462
+
463
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
464
+ [`PreTrainedTokenizer.__call__`] for details.
465
+
466
+ [What are input IDs?](../glossary#input-ids)
467
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
468
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
469
+
470
+ - 1 for tokens that are **not masked**,
471
+ - 0 for tokens that are **masked**.
472
+
473
+ [What are attention masks?](../glossary#attention-mask)
474
+
475
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
476
+ [`PreTrainedTokenizer.__call__`] for details.
477
+
478
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
479
+ `past_key_values`).
480
+
481
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
482
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
483
+ information on the default strategy.
484
+
485
+ - 1 indicates the head is **not masked**,
486
+ - 0 indicates the head is **masked**.
487
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
488
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
489
+ config.n_positions - 1]`.
490
+
491
+ [What are position IDs?](../glossary#position-ids)
492
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
493
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
494
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
495
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
496
+
497
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
498
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
499
+
500
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
501
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
502
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
503
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
504
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
505
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
506
+ model's internal embedding lookup matrix.
507
+ use_cache (`bool`, *optional*):
508
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
509
+ `past_key_values`).
510
+ output_attentions (`bool`, *optional*):
511
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
512
+ tensors for more detail.
513
+ output_hidden_states (`bool`, *optional*):
514
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
515
+ more detail.
516
+ return_dict (`bool`, *optional*):
517
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
518
+ """
519
+
520
+
521
+ @add_start_docstrings(
522
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
523
+ LLAMA_START_DOCSTRING,
524
+ )
525
+ class LlamaModel(LlamaPreTrainedModel):
526
+ """
527
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
528
+
529
+ Args:
530
+ config: LlamaConfig
531
+ """
532
+
533
+ def __init__(self, config: LlamaConfig):
534
+ super().__init__(config)
535
+ self.padding_idx = config.pad_token_id
536
+ self.vocab_size = config.vocab_size
537
+
538
+ self.embed_tokens = nn.Embedding(
539
+ config.vocab_size, config.hidden_size, self.padding_idx
540
+ )
541
+ self.layers = nn.ModuleList(
542
+ [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
543
+ )
544
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
545
+
546
+ self.gradient_checkpointing = False
547
+ # Initialize weights and apply final processing
548
+ self.post_init()
549
+
550
+ def get_input_embeddings(self):
551
+ return self.embed_tokens
552
+
553
+ def set_input_embeddings(self, value):
554
+ self.embed_tokens = value
555
+
556
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
557
+ def _prepare_decoder_attention_mask(
558
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
559
+ ):
560
+ # create causal mask
561
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
562
+ combined_attention_mask = None
563
+ if input_shape[-1] > 1:
564
+ combined_attention_mask = _make_causal_mask(
565
+ input_shape,
566
+ inputs_embeds.dtype,
567
+ device=inputs_embeds.device,
568
+ past_key_values_length=past_key_values_length,
569
+ )
570
+
571
+ if attention_mask is not None:
572
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
573
+ expanded_attn_mask = _expand_mask(
574
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
575
+ ).to(inputs_embeds.device)
576
+ combined_attention_mask = (
577
+ expanded_attn_mask
578
+ if combined_attention_mask is None
579
+ else expanded_attn_mask + combined_attention_mask
580
+ )
581
+
582
+ return combined_attention_mask
583
+
584
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
585
+ def forward(
586
+ self,
587
+ input_ids: torch.LongTensor = None,
588
+ attention_mask: Optional[torch.Tensor] = None,
589
+ position_ids: Optional[torch.LongTensor] = None,
590
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
591
+ inputs_embeds: Optional[torch.FloatTensor] = None,
592
+ use_cache: Optional[bool] = None,
593
+ output_attentions: Optional[bool] = None,
594
+ output_hidden_states: Optional[bool] = None,
595
+ return_dict: Optional[bool] = None,
596
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
597
+ output_attentions = (
598
+ output_attentions
599
+ if output_attentions is not None
600
+ else self.config.output_attentions
601
+ )
602
+ output_hidden_states = (
603
+ output_hidden_states
604
+ if output_hidden_states is not None
605
+ else self.config.output_hidden_states
606
+ )
607
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
608
+
609
+ return_dict = (
610
+ return_dict if return_dict is not None else self.config.use_return_dict
611
+ )
612
+
613
+ # retrieve input_ids and inputs_embeds
614
+ if input_ids is not None and inputs_embeds is not None:
615
+ raise ValueError(
616
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
617
+ )
618
+ elif input_ids is not None:
619
+ batch_size, seq_length = input_ids.shape
620
+ elif inputs_embeds is not None:
621
+ batch_size, seq_length, _ = inputs_embeds.shape
622
+ else:
623
+ raise ValueError(
624
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
625
+ )
626
+
627
+ seq_length_with_past = seq_length
628
+ past_key_values_length = 0
629
+
630
+ if past_key_values is not None:
631
+ past_key_values_length = past_key_values[0][0].shape[2]
632
+ seq_length_with_past = seq_length_with_past + past_key_values_length
633
+
634
+ if position_ids is None:
635
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
636
+ position_ids = torch.arange(
637
+ past_key_values_length,
638
+ seq_length + past_key_values_length,
639
+ dtype=torch.long,
640
+ device=device,
641
+ )
642
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
643
+ else:
644
+ position_ids = position_ids.view(-1, seq_length).long()
645
+
646
+ if inputs_embeds is None:
647
+ inputs_embeds = self.embed_tokens(input_ids)
648
+ # embed positions
649
+ if attention_mask is None:
650
+ attention_mask = torch.ones(
651
+ (batch_size, seq_length_with_past),
652
+ dtype=torch.bool,
653
+ device=inputs_embeds.device,
654
+ )
655
+ attention_mask = self._prepare_decoder_attention_mask(
656
+ attention_mask,
657
+ (batch_size, seq_length),
658
+ inputs_embeds,
659
+ past_key_values_length,
660
+ )
661
+
662
+ hidden_states = inputs_embeds
663
+
664
+ if self.gradient_checkpointing and self.training:
665
+ if use_cache:
666
+ logger.warning_once(
667
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
668
+ )
669
+ use_cache = False
670
+
671
+ # decoder layers
672
+ all_hidden_states = () if output_hidden_states else None
673
+ all_self_attns = () if output_attentions else None
674
+ next_decoder_cache = () if use_cache else None
675
+
676
+ for idx, decoder_layer in enumerate(self.layers):
677
+ if output_hidden_states:
678
+ all_hidden_states += (hidden_states,)
679
+
680
+ past_key_value = (
681
+ past_key_values[idx] if past_key_values is not None else None
682
+ )
683
+
684
+ if self.gradient_checkpointing and self.training:
685
+
686
+ def create_custom_forward(module):
687
+ def custom_forward(*inputs):
688
+ # None for past_key_value
689
+ return module(*inputs, output_attentions, None)
690
+
691
+ return custom_forward
692
+
693
+ layer_outputs = torch.utils.checkpoint.checkpoint(
694
+ create_custom_forward(decoder_layer),
695
+ hidden_states,
696
+ attention_mask,
697
+ position_ids,
698
+ None,
699
+ )
700
+ else:
701
+ layer_outputs = decoder_layer(
702
+ hidden_states,
703
+ attention_mask=attention_mask,
704
+ position_ids=position_ids,
705
+ past_key_value=past_key_value,
706
+ output_attentions=output_attentions,
707
+ use_cache=use_cache,
708
+ )
709
+
710
+ hidden_states = layer_outputs[0]
711
+
712
+ if use_cache:
713
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
714
+
715
+ if output_attentions:
716
+ all_self_attns += (layer_outputs[1],)
717
+
718
+ hidden_states = self.norm(hidden_states)
719
+
720
+ # add hidden states from the last decoder layer
721
+ if output_hidden_states:
722
+ all_hidden_states += (hidden_states,)
723
+
724
+ next_cache = next_decoder_cache if use_cache else None
725
+ if not return_dict:
726
+ return tuple(
727
+ v
728
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
729
+ if v is not None
730
+ )
731
+ return BaseModelOutputWithPast(
732
+ last_hidden_state=hidden_states,
733
+ past_key_values=next_cache,
734
+ hidden_states=all_hidden_states,
735
+ attentions=all_self_attns,
736
+ )
737
+
738
+
739
+ class LlamaForCausalLM(LlamaPreTrainedModel):
740
+ def __init__(self, config):
741
+ super().__init__(config)
742
+ self.model = LlamaModel(config)
743
+
744
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
745
+
746
+ # Initialize weights and apply final processing
747
+ self.post_init()
748
+
749
+ def get_input_embeddings(self):
750
+ return self.model.embed_tokens
751
+
752
+ def set_input_embeddings(self, value):
753
+ self.model.embed_tokens = value
754
+
755
+ def get_output_embeddings(self):
756
+ return self.lm_head
757
+
758
+ def set_output_embeddings(self, new_embeddings):
759
+ self.lm_head = new_embeddings
760
+
761
+ def set_decoder(self, decoder):
762
+ self.model = decoder
763
+
764
+ def get_decoder(self):
765
+ return self.model
766
+
767
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
768
+ @replace_return_docstrings(
769
+ output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
770
+ )
771
+ def forward(
772
+ self,
773
+ input_ids: torch.LongTensor = None,
774
+ attention_mask: Optional[torch.Tensor] = None,
775
+ position_ids: Optional[torch.LongTensor] = None,
776
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
777
+ inputs_embeds: Optional[torch.FloatTensor] = None,
778
+ labels: Optional[torch.LongTensor] = None,
779
+ use_cache: Optional[bool] = None,
780
+ output_attentions: Optional[bool] = None,
781
+ output_hidden_states: Optional[bool] = None,
782
+ return_dict: Optional[bool] = None,
783
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
784
+ r"""
785
+ Args:
786
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
787
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
788
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
789
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
790
+
791
+ Returns:
792
+
793
+ Example:
794
+
795
+ ```python
796
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
797
+
798
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
799
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
800
+
801
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
802
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
803
+
804
+ >>> # Generate
805
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
806
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
807
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
808
+ ```"""
809
+
810
+ output_attentions = (
811
+ output_attentions
812
+ if output_attentions is not None
813
+ else self.config.output_attentions
814
+ )
815
+ output_hidden_states = (
816
+ output_hidden_states
817
+ if output_hidden_states is not None
818
+ else self.config.output_hidden_states
819
+ )
820
+ return_dict = (
821
+ return_dict if return_dict is not None else self.config.use_return_dict
822
+ )
823
+
824
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
825
+ outputs = self.model(
826
+ input_ids=input_ids,
827
+ attention_mask=attention_mask,
828
+ position_ids=position_ids,
829
+ past_key_values=past_key_values,
830
+ inputs_embeds=inputs_embeds,
831
+ use_cache=use_cache,
832
+ output_attentions=output_attentions,
833
+ output_hidden_states=output_hidden_states,
834
+ return_dict=return_dict,
835
+ )
836
+
837
+ hidden_states = outputs[0]
838
+ logits = self.lm_head(hidden_states)
839
+
840
+ loss = None
841
+ if labels is not None:
842
+ # Shift so that tokens < n predict n
843
+ shift_logits = logits[..., :-1, :].contiguous()
844
+ shift_labels = labels[..., 1:].contiguous()
845
+ # Flatten the tokens
846
+ loss_fct = CrossEntropyLoss()
847
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
848
+ shift_labels = shift_labels.view(-1)
849
+ # Enable model parallelism
850
+ shift_labels = shift_labels.to(shift_logits.device)
851
+ loss = loss_fct(shift_logits, shift_labels)
852
+
853
+ if not return_dict:
854
+ output = (logits,) + outputs[1:]
855
+ return (loss,) + output if loss is not None else output
856
+
857
+ return CausalLMOutputWithPast(
858
+ loss=loss,
859
+ logits=logits,
860
+ past_key_values=outputs.past_key_values,
861
+ hidden_states=outputs.hidden_states,
862
+ attentions=outputs.attentions,
863
+ )
864
+
865
+ def prepare_inputs_for_generation(
866
+ self,
867
+ input_ids,
868
+ past_key_values=None,
869
+ attention_mask=None,
870
+ inputs_embeds=None,
871
+ **kwargs,
872
+ ):
873
+ if past_key_values:
874
+ input_ids = input_ids[:, -1:]
875
+
876
+ position_ids = kwargs.get("position_ids", None)
877
+ if attention_mask is not None and position_ids is None:
878
+ # create position_ids on the fly for batch generation
879
+ position_ids = attention_mask.long().cumsum(-1) - 1
880
+ position_ids.masked_fill_(attention_mask == 0, 1)
881
+ if past_key_values:
882
+ position_ids = position_ids[:, -1].unsqueeze(-1)
883
+
884
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
885
+ if inputs_embeds is not None and past_key_values is None:
886
+ model_inputs = {"inputs_embeds": inputs_embeds}
887
+ else:
888
+ model_inputs = {"input_ids": input_ids}
889
+
890
+ model_inputs.update(
891
+ {
892
+ "position_ids": position_ids,
893
+ "past_key_values": past_key_values,
894
+ "use_cache": kwargs.get("use_cache"),
895
+ "attention_mask": attention_mask,
896
+ }
897
+ )
898
+ return model_inputs
899
+
900
+ @staticmethod
901
+ def _reorder_cache(past_key_values, beam_idx):
902
+ reordered_past = ()
903
+ for layer_past in past_key_values:
904
+ reordered_past += (
905
+ tuple(
906
+ past_state.index_select(0, beam_idx) for past_state in layer_past
907
+ ),
908
+ )
909
+ return reordered_past
910
+
911
+
912
+ @add_start_docstrings(
913
+ """
914
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
915
+
916
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
917
+ (e.g. GPT-2) do.
918
+
919
+ Since it does classification on the last token, it requires to know the position of the last token. If a
920
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
921
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
922
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
923
+ each row of the batch).
924
+ """,
925
+ LLAMA_START_DOCSTRING,
926
+ )
927
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
928
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
929
+
930
+ def __init__(self, config):
931
+ super().__init__(config)
932
+ self.num_labels = config.num_labels
933
+ self.model = LlamaModel(config)
934
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
935
+
936
+ # Initialize weights and apply final processing
937
+ self.post_init()
938
+
939
+ def get_input_embeddings(self):
940
+ return self.model.embed_tokens
941
+
942
+ def set_input_embeddings(self, value):
943
+ self.model.embed_tokens = value
944
+
945
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
946
+ def forward(
947
+ self,
948
+ input_ids: torch.LongTensor = None,
949
+ attention_mask: Optional[torch.Tensor] = None,
950
+ position_ids: Optional[torch.LongTensor] = None,
951
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
952
+ inputs_embeds: Optional[torch.FloatTensor] = None,
953
+ labels: Optional[torch.LongTensor] = None,
954
+ use_cache: Optional[bool] = None,
955
+ output_attentions: Optional[bool] = None,
956
+ output_hidden_states: Optional[bool] = None,
957
+ return_dict: Optional[bool] = None,
958
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
959
+ r"""
960
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
961
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
962
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
963
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
964
+ """
965
+ return_dict = (
966
+ return_dict if return_dict is not None else self.config.use_return_dict
967
+ )
968
+
969
+ transformer_outputs = self.model(
970
+ input_ids,
971
+ attention_mask=attention_mask,
972
+ position_ids=position_ids,
973
+ past_key_values=past_key_values,
974
+ inputs_embeds=inputs_embeds,
975
+ use_cache=use_cache,
976
+ output_attentions=output_attentions,
977
+ output_hidden_states=output_hidden_states,
978
+ return_dict=return_dict,
979
+ )
980
+ hidden_states = transformer_outputs[0]
981
+ logits = self.score(hidden_states)
982
+
983
+ if input_ids is not None:
984
+ batch_size = input_ids.shape[0]
985
+ else:
986
+ batch_size = inputs_embeds.shape[0]
987
+
988
+ if self.config.pad_token_id is None and batch_size != 1:
989
+ raise ValueError(
990
+ "Cannot handle batch sizes > 1 if no padding token is defined."
991
+ )
992
+ if self.config.pad_token_id is None:
993
+ sequence_lengths = -1
994
+ else:
995
+ if input_ids is not None:
996
+ sequence_lengths = (
997
+ torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
998
+ ).to(logits.device)
999
+ else:
1000
+ sequence_lengths = -1
1001
+
1002
+ pooled_logits = logits[
1003
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1004
+ ]
1005
+
1006
+ loss = None
1007
+ if labels is not None:
1008
+ labels = labels.to(logits.device)
1009
+ if self.config.problem_type is None:
1010
+ if self.num_labels == 1:
1011
+ self.config.problem_type = "regression"
1012
+ elif self.num_labels > 1 and (
1013
+ labels.dtype == torch.long or labels.dtype == torch.int
1014
+ ):
1015
+ self.config.problem_type = "single_label_classification"
1016
+ else:
1017
+ self.config.problem_type = "multi_label_classification"
1018
+
1019
+ if self.config.problem_type == "regression":
1020
+ loss_fct = MSELoss()
1021
+ if self.num_labels == 1:
1022
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1023
+ else:
1024
+ loss = loss_fct(pooled_logits, labels)
1025
+ elif self.config.problem_type == "single_label_classification":
1026
+ loss_fct = CrossEntropyLoss()
1027
+ loss = loss_fct(
1028
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1029
+ )
1030
+ elif self.config.problem_type == "multi_label_classification":
1031
+ loss_fct = BCEWithLogitsLoss()
1032
+ loss = loss_fct(pooled_logits, labels)
1033
+ if not return_dict:
1034
+ output = (pooled_logits,) + transformer_outputs[1:]
1035
+ return ((loss,) + output) if loss is not None else output
1036
+
1037
+ return SequenceClassifierOutputWithPast(
1038
+ loss=loss,
1039
+ logits=pooled_logits,
1040
+ past_key_values=transformer_outputs.past_key_values,
1041
+ hidden_states=transformer_outputs.hidden_states,
1042
+ attentions=transformer_outputs.attentions,
1043
+ )
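
The file above follows the standard LLaMA decoder recipe (RMSNorm, rotary position embeddings, causal masking); one visible local change is that `LlamaAttention.forward` returns the pre-softmax attention scores (`unnormed_attn_weights`) rather than the normalized ones. Below is a small shape-check sketch for the rotary helpers, assuming `torch` and `transformers` are installed and the working directory is the repo root; the module is loaded by file path because the directory name `valle_v2.1` contains a dot and cannot be imported with dotted syntax.

```python
import importlib.util
import torch

# Load the modified modeling_llama module by file path.
spec = importlib.util.spec_from_file_location(
    "valle_modeling_llama", "models/tts/valle_v2.1/modeling_llama.py"
)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

bsz, n_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)

rope = mod.LlamaRotaryEmbedding(head_dim, max_position_embeddings=32)
cos, sin = rope(q, seq_len=seq_len)  # each [1, 1, seq_len, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)
q_rot, k_rot = mod.apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)      # both torch.Size([2, 4, 8, 16])
```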
models/tts/valle_v2.1/train.py ADDED
@@ -0,0 +1,19 @@
1
+ from omegaconf import DictConfig, OmegaConf
2
+ from typing import Optional
3
+ import hydra
4
+ import os
5
+
6
+
7
+ def train(cfg):
8
+ trainer = hydra.utils.instantiate(cfg.trainer)
9
+ trainer.train_loop()
10
+
11
+
12
+ @hydra.main(version_base="1.3", config_path="./cfg", config_name="base.yaml")
13
+ def main(cfg: DictConfig) -> Optional[float]:
14
+ # train the model
15
+ train(cfg)
16
+
17
+
18
+ if __name__ == "__main__":
19
+ main()
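
The script above is a thin Hydra entry point: `@hydra.main` loads `./cfg/base.yaml` (resolved relative to this file) and `hydra.utils.instantiate(cfg.trainer)` builds whatever class the config's `_target_` names, passing the remaining keys as constructor arguments. A minimal sketch of that mechanism follows; the `_target_` is a stand-in (`collections.OrderedDict`), not the repo's actual trainer class, and the real config file is not part of this diff.

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Stand-in config: in the real setup, cfg.trainer._target_ would name the VALL-E trainer class.
cfg = OmegaConf.create(
    {"trainer": {"_target_": "collections.OrderedDict", "exp_name": "demo"}}
)
obj = instantiate(cfg.trainer)  # imports collections.OrderedDict, calls it with exp_name="demo"
print(type(obj).__name__, dict(obj))  # OrderedDict {'exp_name': 'demo'}
```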
models/tts/valle_v2.1/valle_ar.py ADDED
@@ -0,0 +1,302 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .modeling_llama import LlamaConfig, LlamaForCausalLM, LlamaModel
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import os
11
+ import torch.nn as nn
12
+
13
+
14
+ class ValleAR(nn.Module):
15
+ def __init__(
16
+ self,
17
+ phone_vocab_size=256,
18
+ target_vocab_size=1024,
19
+ hidden_size=1024,
20
+ intermediate_size=4096,
21
+ num_hidden_layers=12,
22
+ num_attention_heads=16,
23
+ pad_token_id=1281,
24
+ bos_target_id=1282,
25
+ eos_target_id=1283,
26
+ bos_phone_id=1284,
27
+ eos_phone_id=1285,
28
+ use_input_embeds=False,
29
+ emb_dim=256,
30
+ **kwargs,
31
+ ):
32
+ super(ValleAR, self).__init__()
33
+ self.config = LlamaConfig(
34
+ vocab_size=phone_vocab_size + target_vocab_size + 10,
35
+ hidden_size=hidden_size,
36
+ intermediate_size=intermediate_size,
37
+ num_hidden_layers=num_hidden_layers,
38
+ num_attention_heads=num_attention_heads,
39
+ pad_token_id=pad_token_id,
40
+ bos_token_id=bos_target_id,
41
+ eos_token_id=eos_target_id,
42
+ )
43
+ self.phone_vocab_size = phone_vocab_size
44
+ self.target_vocab_size = target_vocab_size
45
+ self.pad_token_id = pad_token_id
46
+ self.bos_target_id = bos_target_id
47
+ self.eos_target_id = eos_target_id
48
+ self.bos_phone_id = bos_phone_id
49
+ self.eos_phone_id = eos_phone_id
50
+ self.model = LlamaForCausalLM(self.config)
51
+
52
+ self.use_input_embeds = use_input_embeds
53
+
54
+ # no input embedding is used to provide speaker information
55
+ if self.use_input_embeds:
56
+ self.emb_linear = nn.Linear(emb_dim, hidden_size)
57
+ self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
58
+ self.emb_linear.bias.data.zero_()
59
+
60
+ def forward(
61
+ self, phone_ids, phone_mask, target_ids, target_mask, input_embeds=None
62
+ ):
63
+ if input_embeds is not None:
64
+ input_embeds = self.emb_linear(input_embeds)
65
+ phone_ids, phone_mask, phone_label = self.add_phone_eos_bos_label(
66
+ phone_ids,
67
+ phone_mask,
68
+ self.eos_phone_id,
69
+ self.bos_phone_id,
70
+ self.pad_token_id,
71
+ )
72
+ target_ids, target_mask, target_label = self.add_target_eos_bos_label(
73
+ target_ids,
74
+ target_mask,
75
+ self.eos_target_id,
76
+ self.bos_target_id,
77
+ self.pad_token_id,
78
+ )
79
+ input_token_ids = torch.cat([phone_ids, target_ids], dim=-1)
80
+ attention_mask = torch.cat([phone_mask, target_mask], dim=-1)
81
+ # breakpoint()
82
+ if input_embeds is not None:
83
+ raise NotImplementedError
84
+ attention_mask = torch.cat(
85
+ [
86
+ torch.ones(
87
+ (input_embeds.shape[0], input_embeds.shape[1]),
88
+ dtype=attention_mask.dtype,
89
+ device=attention_mask.device,
90
+ ),
91
+ attention_mask,
92
+ ],
93
+ dim=-1,
94
+ )
95
+ labels = torch.cat([phone_label, target_label], dim=-1)
96
+ if input_embeds is not None:
97
+ raise NotImplementedError
98
+ labels = torch.cat(
99
+ [
100
+ -100
101
+ * torch.ones(
102
+ (input_embeds.shape[0], input_embeds.shape[1]),
103
+ dtype=labels.dtype,
104
+ device=labels.device,
105
+ ),
106
+ labels,
107
+ ],
108
+ dim=-1,
109
+ )
110
+
111
+ if input_embeds is not None:
112
+ raise NotImplementedError
113
+ inputs_embeds = torch.cat(
114
+ [input_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1
115
+ )
116
+ out = self.model(
117
+ inputs_embeds=inputs_embeds,
118
+ attention_mask=attention_mask,
119
+ labels=labels,
120
+ return_dict=True,
121
+ )
122
+ return out
123
+
124
+ out = self.model(
125
+ input_token_ids,
126
+ attention_mask=attention_mask,
127
+ labels=labels,
128
+ return_dict=True,
129
+ )
130
+
131
+ # calcualte top1, top5, top10 accuracy
132
+ logits = out.logits
133
+ logits = logits[:, -target_ids.shape[1] :]
134
+ top1_acc = logits.argmax(-1)[..., :-1] == target_ids[:, 1:]
135
+ top1_acc = (top1_acc * target_mask[..., :-1]).sum() / target_mask.sum()
136
+
137
+ top5_acc = torch.topk(logits[..., :-1, :], 5, dim=-1)[1]
138
+ top5_acc = top5_acc == target_ids[:, 1:].unsqueeze(-1)
139
+ top5_acc = (
140
+ top5_acc * target_mask[..., :-1].unsqueeze(-1)
141
+ ).sum() / target_mask.sum()
142
+
143
+ top10_acc = torch.topk(logits[..., :-1, :], 10, dim=-1)[1]
144
+ top10_acc = top10_acc == target_ids[:, 1:].unsqueeze(-1)
145
+ top10_acc = (
146
+ top10_acc * target_mask[..., :-1].unsqueeze(-1)
147
+ ).sum() / target_mask.sum()
148
+
149
+ out.top1_acc = top1_acc
150
+ out.top5_acc = top5_acc
151
+ out.top10_acc = top10_acc
152
+
153
+ return out
154
+
155
+ def add_phone_eos_bos_label(
156
+ self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
157
+ ):
158
+ # phone_ids: [B, T]
159
+ # phone_mask: [B, T]
160
+
161
+ phone_ids = phone_ids + self.target_vocab_size * phone_mask
162
+
163
+ phone_ids = phone_ids * phone_mask
164
+ phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
165
+ 1 - phone_mask, (0, 1), value=1
166
+ ) # make pad token eos token, add eos token at the end
167
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask
168
+ phone_ids = phone_ids * phone_mask + pad_token_id * (
169
+ 1 - phone_mask
170
+ ) # restore pad token ids
171
+ phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token
172
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask
173
+ phone_label = -100 * torch.ones_like(
174
+ phone_ids
175
+ ) # loss for entire phone is not computed (passed to llama)
176
+ return phone_ids, phone_mask, phone_label
177
+
178
+ def add_target_eos_bos_label(
179
+ self, target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id
180
+ ):
181
+ # target_ids: [B, T]
182
+ # target_mask: [B, T]
183
+ target_ids = target_ids * target_mask
184
+ target_ids = F.pad(target_ids, (0, 1), value=0) + target_eos_id * F.pad(
185
+ 1 - target_mask, (0, 1), value=1
186
+ )
187
+ target_mask = F.pad(target_mask, (1, 0), value=1)
188
+ target_ids = target_ids * target_mask + pad_token_id * (1 - target_mask)
189
+ target_ids = F.pad(target_ids, (1, 0), value=target_bos_id)
190
+ target_mask = F.pad(target_mask, (1, 0), value=1)
191
+ target_label = target_ids * target_mask + (-100) * (
192
+ 1 - target_mask
193
+ ) # loss for target is computed on unmasked tokens
194
+ return target_ids, target_mask, target_label
195
+
196
+ def sample_hf(
197
+ self,
198
+ phone_ids, # the phones of prompt and target should be concatenated together
199
+ prompt_ids,
200
+ inputs_embeds=None,
201
+ max_length=2000,
202
+ temperature=1.0,
203
+ top_k=100,
204
+ top_p=0.9,
205
+ repeat_penalty=1.0,
206
+ num_beams=1,
207
+ ):
208
+ if inputs_embeds is not None:
209
+ inputs_embeds = self.emb_linear(inputs_embeds)
210
+ phone_mask = torch.ones_like(phone_ids)
211
+ prompt_mask = torch.ones_like(prompt_ids)
212
+ phone_ids, _, _ = self.add_phone_eos_bos_label(
213
+ phone_ids,
214
+ phone_mask,
215
+ self.eos_phone_id,
216
+ self.bos_phone_id,
217
+ self.pad_token_id,
218
+ )
219
+ prompt_ids, _, _ = self.add_target_eos_bos_label(
220
+ prompt_ids,
221
+ prompt_mask,
222
+ self.eos_target_id,
223
+ self.bos_target_id,
224
+ self.pad_token_id,
225
+ )
226
+ prompt_ids = prompt_ids[:, :-1] # remove end token. Make it continue mode
227
+
228
+ input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1)
229
+
230
+ if inputs_embeds is not None:
231
+ raise NotImplementedError
232
+ inputs_embeds = torch.cat(
233
+ [inputs_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1
234
+ )
235
+ generated_ids = self.model.generate(
236
+ inputs_embeds=inputs_embeds,
237
+ do_sample=True,
238
+ max_length=max_length,
239
+ pad_token_id=self.pad_token_id,
240
+ eos_token_id=self.eos_target_id,
241
+ temperature=temperature,
242
+ top_k=top_k,
243
+ top_p=top_p,
244
+ repetition_penalty=repeat_penalty,
245
+ )
246
+ gen_tokens = generated_ids[:, :-1]
247
+ return gen_tokens
248
+
249
+ input_length = input_token_ids.shape[1]
250
+ generated_ids = self.model.generate(
251
+ input_token_ids,
252
+ do_sample=True,
253
+ max_length=max_length,
254
+ pad_token_id=self.pad_token_id,
255
+ eos_token_id=self.eos_target_id,
256
+ temperature=temperature,
257
+ top_k=top_k,
258
+ top_p=top_p,
259
+ repetition_penalty=repeat_penalty,
260
+ num_beams=num_beams,
261
+ )
262
+
263
+ gen_tokens = generated_ids[:, input_length:-1]
264
+
265
+ return gen_tokens
266
+
267
+
268
+ def test():
269
+ model = ValleAR()
270
+
271
+ phone_ids = torch.LongTensor([[1, 2, 3, 4, 5, 0], [1, 2, 3, 4, 5, 6]])
272
+ phone_mask = torch.LongTensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])
273
+ target_ids = torch.LongTensor([765, 234, 123, 234, 123, 599]).expand(2, -1)
274
+ target_mask = torch.LongTensor([1, 1, 1, 1, 0, 0]).expand(2, -1)
275
+
276
+ optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
277
+
278
+ for i in range(15):
279
+ optimizer.zero_grad()
280
+ out = model(
281
+ phone_ids=phone_ids,
282
+ phone_mask=phone_mask,
283
+ target_ids=target_ids,
284
+ target_mask=target_mask,
285
+ )
286
+ loss = out.loss
287
+
288
+ loss.backward()
289
+
290
+ optimizer.step()
291
+
292
+ print(f"iter={i}, {loss}.")
293
+
294
+ phone_ids = torch.LongTensor([1, 2, 3]).reshape(1, -1)
295
+ target_ids = torch.LongTensor([765, 234]).reshape(1, -1)
296
+ sampled = model.sample_hf(phone_ids, target_ids)
297
+
298
+ breakpoint()
299
+
300
+
301
+ if __name__ == "__main__":
302
+ test()
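The token bookkeeping in add_phone_eos_bos_label is easiest to check on a toy input. The sketch below, not part of the commit, replays the same F.pad arithmetic with the default special ids of ValleAR so the result can be verified by hand: phones are shifted above the codec codebook, the first padded slot becomes EOS, the remaining padding becomes PAD, and BOS is prepended.

import torch
import torch.nn.functional as F

# default special ids from ValleAR.__init__
TARGET_VOCAB, BOS_PHONE, EOS_PHONE, PAD = 1024, 1284, 1285, 1281

phone_ids = torch.LongTensor([[5, 7, 9, 0, 0]])
phone_mask = torch.LongTensor([[1, 1, 1, 0, 0]])

ids = phone_ids + TARGET_VOCAB * phone_mask                    # shift phones above the codec codebook
ids = ids * phone_mask
ids = F.pad(ids, (0, 1), value=0) + EOS_PHONE * F.pad(1 - phone_mask, (0, 1), value=1)
mask = F.pad(phone_mask, (1, 0), value=1)                      # first padded slot becomes EOS
ids = ids * mask + PAD * (1 - mask)                            # remaining padded slots become PAD
ids = F.pad(ids, (1, 0), value=BOS_PHONE)                      # prepend BOS
mask = F.pad(mask, (1, 0), value=1)

print(ids)   # tensor([[1284, 1029, 1031, 1033, 1285, 1281, 1281]])
print(mask)  # tensor([[1, 1, 1, 1, 1, 0, 0]])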
models/tts/valle_v2.1/valle_ar_trainer.py ADDED
@@ -0,0 +1,371 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import shutil
9
+ import torch
10
+ import time
11
+ from pathlib import Path
12
+ import torch
13
+ from tqdm import tqdm
14
+ import torch.nn as nn
15
+ from .base_trainer import BaseTrainer
16
+
17
+
18
+ def make_pad_mask(
19
+ lengths: torch.Tensor, max_len: int = 0, left_pad=False
20
+ ) -> torch.Tensor:
21
+ """
22
+ Args:
23
+ lengths:
24
+ A 1-D tensor containing sentence lengths.
25
+ max_len:
26
+ The length of masks.
27
+ left_pad:
28
+ A boolean indicating whether to left pad the mask.
29
+ Returns:
30
+ Return a 2-D bool tensor, where masked positions
31
+ are filled with `True` and non-masked positions are
32
+ filled with `False`.
33
+
34
+ >>> lengths = torch.tensor([1, 3, 2, 5])
35
+ >>> make_pad_mask(lengths)
36
+ tensor([[False, True, True, True, True],
37
+ [False, False, False, True, True],
38
+ [False, False, True, True, True],
39
+ [False, False, False, False, False]])
40
+ """
41
+ assert lengths.ndim == 1, lengths.ndim
42
+ max_len = max(max_len, lengths.max())
43
+ n = lengths.size(0)
44
+ seq_range = torch.arange(0, max_len, device=lengths.device)
45
+ expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
46
+ mask = expanded_lengths >= lengths.unsqueeze(-1)
47
+
48
+ if left_pad:
49
+ mask = mask.flip(dims=[1])
50
+
51
+ return mask
52
+
53
+
54
+ class ValleARTrainer(BaseTrainer):
55
+ def __init__(self, args=None, cfg=None):
56
+ super().__init__(args, cfg)
57
+ if self.cfg.use_speechtokenizer:
58
+ from models.codec.speechtokenizer.model import SpeechTokenizer
59
+
60
+ config_path = "./ckpts/speechtokenizer_hubert_avg/config.json"
61
+ ckpt_path = "./ckpts/speechtokenizer_hubert_avg/SpeechTokenizer.pt"
62
+ assert os.path.isfile(
63
+ config_path
64
+ ), f"codec model {config_path} not found! Download with huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts"
65
+ assert os.path.isfile(
66
+ ckpt_path
67
+ ), f"codec model {ckpt_path} not found! Download with huggingface-cli download fnlp/SpeechTokenizer speechtokenizer_hubert_avg/SpeechTokenizer.pt speechtokenizer_hubert_avg/config.json --local-dir ckpts"
68
+ self.codec_encoder = SpeechTokenizer.load_from_checkpoint(
69
+ config_path, ckpt_path
70
+ )
71
+ self.codec_encoder.eval()
72
+ self.codec_encoder.to(self.accelerator.device)
73
+ print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}")
74
+ else:
75
+ from encodec import EncodecModel
76
+
77
+ with self.accelerator.main_process_first():
78
+ self.codec_encoder = EncodecModel.encodec_model_24khz()
79
+ self.codec_encoder.set_target_bandwidth(6.0)
80
+ self.codec_encoder.to(self.accelerator.device)
81
+ self.codec_decoder = None
82
+ print("Loaded EncodecModel")
83
+ self.top1_accuracies = []
84
+ self.top5_accuracies = []
85
+ self.top10_accuracies = []
86
+
87
+ if hasattr(self.cfg, "flatten_first_2_layers"):
88
+ self.flatten_first_2_layers = self.cfg.flatten_first_2_layers
89
+ print("flattened:", self.flatten_first_2_layers)
90
+ else:
91
+ self.flatten_first_2_layers = False
92
+
93
+ if hasattr(self.cfg, "num_prediction_heads"):
94
+ self.num_prediction_heads = self.cfg.num_prediction_heads
95
+ print("num_prediction_heads:", self.num_prediction_heads)
96
+
97
+ def _accelerator_prepare(self):
98
+ # if self.accelerator.is_main_process:
99
+ # breakpoint()
100
+ # self.accelerator.wait_for_everyone()
101
+
102
+ (
103
+ self.model,
104
+ self.optimizer,
105
+ ) = self.accelerator.prepare(
106
+ self.model,
107
+ self.optimizer,
108
+ )
109
+
110
+ def _build_criterion(self):
111
+ pass # loss is directly returned from model
112
+
113
+ def _build_scheduler(self):
114
+ from transformers import (
115
+ get_cosine_schedule_with_warmup,
116
+ get_constant_schedule_with_warmup,
117
+ )
118
+
119
+ return get_cosine_schedule_with_warmup(
120
+ self.optimizer,
121
+ num_warmup_steps=self.cfg.train.scheduler.warmup_steps,
122
+ num_training_steps=self.cfg.train.scheduler.total_steps,
123
+ )
124
+
125
+ def _build_model(self):
126
+ if hasattr(self.cfg.model, "num_prediction_heads"):
127
+ from .valle_ar_multihead import ValleAR
128
+ else:
129
+ from .valle_ar import ValleAR
130
+ return ValleAR(**self.cfg.model)
131
+
132
+ def _train_step(self, batch):
133
+ # inference codec
134
+ """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
135
+ speech: [B, T]
136
+ speech_len: [B]
137
+ phone_ids: [B, T]
138
+ phone_lens: [B]
139
+ """
140
+ device = self.accelerator.device
141
+ for k, v in batch.items():
142
+ if isinstance(v, torch.Tensor):
143
+ batch[k] = v.to(device)
144
+ with torch.no_grad():
145
+ if self.cfg.use_speechtokenizer:
146
+ # Extract discrete codes from SpeechTokenizer
147
+ vq_id = self.codec_encoder.encode(
148
+ batch["speech"].unsqueeze(1)
149
+ ) # [B,1,T] -> (n_q, B, T)
150
+ else:
151
+ vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
152
+ vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
153
+ 0, 1
154
+ )
155
+
156
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
157
+ # torchaudio.save('a.wav', recovered_audio[0], 16000)
158
+ # vq_id: [8, B, T//320]
159
+ if self.flatten_first_2_layers:
160
+ first_layer = vq_id[0]
161
+ second_layer = vq_id[1]
162
+ # flatten the first two layers
163
+ batch["speech"] = torch.stack(
164
+ [first_layer, second_layer], dim=-1
165
+ ).flatten(-2, -1)
166
+ batch["speech_len"] = batch["speech_len"] // 160
167
+ elif hasattr(self.cfg.model, "num_prediction_heads"):
168
+ batch["speech"] = vq_id[:2] # first two layers
169
+ batch["speech_len"] = (
170
+ batch["speech_len"] // 320
171
+ ) # our codec downsamples 320x
172
+ else:
173
+ batch["speech"] = vq_id[0] # use first layer
174
+ batch["speech_len"] = (
175
+ batch["speech_len"] // 320
176
+ ) # our codec downsamples 320x
177
+ assert batch["speech_len"].max() <= batch["speech"].shape[-1]
178
+
179
+ phone_mask = 1 - make_pad_mask(
180
+ batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False
181
+ ).to(torch.long)
182
+ speech_mask = 1 - make_pad_mask(
183
+ batch["speech_len"], max_len=batch["speech"].size(1)
184
+ ).to(torch.long)
185
+
186
+ out = self.model(
187
+ phone_ids=batch["phone_ids"],
188
+ phone_mask=phone_mask,
189
+ target_ids=batch["speech"],
190
+ target_mask=speech_mask,
191
+ )
192
+ loss = out.loss
193
+ # if self.accelerator.is_main_process:
194
+ # print(loss)
195
+ # if hasattr(out, 'top1_acc'):
196
+ # self.top1_accuracies.append(out.top1_acc)
197
+ # self.top5_accuracies.append(out.top5_acc)
198
+ # self.top10_accuracies.append(out.top10_acc)
199
+ # print(f'avgs: top1: {sum(self.top1_accuracies)/len(self.top1_accuracies)}, top5: {sum(self.top5_accuracies)/len(self.top5_accuracies)}, top10: {sum(self.top10_accuracies)/len(self.top10_accuracies)}')
200
+ # breakpoint()
201
+ return loss
202
+
203
+ ##########add your own dataloader to the trainer#############
204
+ def _build_dataloader(self):
205
+ from torch.utils.data import ConcatDataset, DataLoader
206
+
207
+ if self.cfg.train.dataset.name == "emilia":
208
+ from .emilia_dataset import EmiliaDataset as VALLEDataset
209
+
210
+ train_dataset = VALLEDataset()
211
+ elif self.cfg.train.dataset.name == "mls":
212
+ from .mls_dataset import VALLEDataset as VALLEDataset
213
+
214
+ train_dataset = VALLEDataset(self.cfg.dataset, resample_to_24k=False)
215
+ elif self.cfg.train.dataset.name == "libritts":
216
+ from .libritts_dataset import VALLEDataset as VALLEDataset
217
+
218
+ train_dataset = VALLEDataset(self.cfg.dataset)
219
+
220
+ from .valle_collator import VALLECollator
221
+ import numpy as np
222
+
223
+ print("length of train_dataset:", len(train_dataset))
224
+
225
+ collator = VALLECollator()
226
+
227
+ if self.cfg.train.dataset.use_dynamic_batchsize:
228
+ if self.accelerator.is_main_process:
229
+ self.logger.info("Use Dynamic Batchsize......")
230
+ from .mls_dataset import batch_by_size
231
+
232
+ batch_sampler = batch_by_size(
233
+ train_dataset.num_frame_indices,
234
+ train_dataset.get_num_frames,
235
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
236
+ max_sentences=self.cfg.train.max_sentences
237
+ * self.accelerator.num_processes,
238
+ required_batch_size_multiple=self.accelerator.num_processes,
239
+ )
240
+ np.random.shuffle(batch_sampler)
241
+ print(batch_sampler[0])
242
+ batches = [
243
+ x[
244
+ self.accelerator.local_process_index :: self.accelerator.num_processes
245
+ ]
246
+ for x in batch_sampler
247
+ if len(x) % self.accelerator.num_processes == 0
248
+ ]
249
+ from models.base.base_sampler import VariableSampler
250
+
251
+ train_loader = DataLoader(
252
+ train_dataset,
253
+ collate_fn=collator,
254
+ num_workers=self.cfg.train.dataloader.num_worker,
255
+ batch_sampler=VariableSampler(
256
+ batches, drop_last=True, use_random_sampler=True
257
+ ),
258
+ pin_memory=self.cfg.train.dataloader.pin_memory,
259
+ persistent_workers=self.cfg.train.dataloader.persistent_workers,
260
+ prefetch_factor=4,
261
+ )
262
+ print(
263
+ f"process {self.accelerator.local_process_index} has {len(batches)} batches"
264
+ )
265
+ self.accelerator.wait_for_everyone()
266
+
267
+ else:
268
+ sampler = torch.utils.data.distributed.DistributedSampler(
269
+ train_dataset,
270
+ num_replicas=self.accelerator.num_processes,
271
+ rank=self.accelerator.local_process_index,
272
+ shuffle=True,
273
+ )
274
+ train_loader = DataLoader(
275
+ train_dataset,
276
+ batch_size=self.cfg.train.batch_size,
277
+ num_workers=self.cfg.train.dataloader.num_worker,
278
+ pin_memory=self.cfg.train.dataloader.pin_memory,
279
+ collate_fn=collator,
280
+ sampler=sampler,
281
+ )
282
+ print(
283
+ f"process {self.accelerator.local_process_index} has {len(train_loader)} batches"
284
+ )
285
+
286
+ return train_loader, None
287
+
288
+ def _test_step(self, batch):
289
+ # inference codec
290
+ """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
291
+ speech: [B, T]
292
+ speech_len: [B]
293
+ phone_ids: [B, T]
294
+ phone_lens: [B]
295
+ """
296
+ import torchaudio
297
+
298
+ device = self.accelerator.device
299
+ for k, v in batch.items():
300
+ if isinstance(v, torch.Tensor):
301
+ batch[k] = v.to(device)
302
+ with torch.no_grad():
303
+ if self.cfg.use_speechtokenizer:
304
+ # Extract discrete codes from SpeechTokenizer
305
+ vq_id = self.codec_encoder.encode(
306
+ batch["speech"].unsqueeze(1)
307
+ ) # [B,1,T] -> (n_q, B, T)
308
+ else:
309
+ vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
310
+ vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
311
+ 0, 1
312
+ )
313
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
314
+ # torchaudio.save('a.wav', recovered_audio[0], 16000)
315
+ # vq_id: [8, B, T//320]
316
+
317
+ # vq_emb = self.codec_decoder.quantizer.vq2emb(vq=vq_id[:1], n_quantizers=1)
318
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
319
+ # recovered_audio.shape: torch.Size([1, 1, 50200])
320
+
321
+ if self.flatten_first_2_layers:
322
+ first_layer = vq_id[0]
323
+ second_layer = vq_id[1]
324
+ # flatten the first two layers
325
+ batch["speech"] = torch.stack(
326
+ [first_layer, second_layer], dim=-1
327
+ ).flatten(-2, -1)
328
+ batch["speech_len"] = batch["speech_len"] // 160
329
+ elif hasattr(self.cfg.model, "num_prediction_heads"):
330
+ batch["speech"] = vq_id[:2] # first two layers
331
+ batch["speech_len"] = (
332
+ batch["speech_len"] // 320
333
+ ) # our codec downsamples 320x
334
+ else:
335
+ batch["speech"] = vq_id[0] # use first layer
336
+ batch["speech_len"] = (
337
+ batch["speech_len"] // 320
338
+ ) # our codec downsamples 320x
339
+
340
+ # save gt
341
+ breakpoint()
342
+ recovered_audio = self.codec_encoder.decode(vq_id[:1, :1])
343
+ # recovered_audio = self.codec_encoder.decode([(vq_id[:1].transpose(0,1), None)])
344
+ torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000)
345
+ out_vq_ids = self.model.sample_hf(
346
+ batch["phone_ids"][:1, ...], batch["speech"][:1, :225], temperature=0.9
347
+ )
348
+ # out_vq_ids = torch.cat([batch['speech'][:1, :225], out_vq_ids[:1, ...]], dim=1)
349
+
350
+ # reconstruct waveform from tokens
351
+ recovered_audio = self.codec_encoder.decode(out_vq_ids.unsqueeze(0))
352
+ # recovered_audio = self.codec_encoder.decode([(out_vq_ids, None)])
353
+ torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000)
354
+ breakpoint()
355
+ print()
356
+
357
+ @torch.inference_mode()
358
+ def _valid_epoch(self):
359
+ r"""Testing epoch. Should return average loss of a batch (sample) over
360
+ one epoch. See ``train_loop`` for usage.
361
+ """
362
+ epoch_sum_loss = 0.0
363
+ return epoch_sum_loss
364
+
365
+ def _inference(self):
366
+ pass
367
+
368
+ def test_loop(self):
369
+ self.model.eval()
370
+ for batch in self.train_dataloader:
371
+ self._test_step(batch)
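ValleARTrainer reads a number of cfg fields that are defined by the experiment config rather than by this commit. The sketch below only lists the keys referenced in the code above, with placeholder values; it is not the project's actual configuration.

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "use_speechtokenizer": True,        # SpeechTokenizer codes vs. EnCodec codes
        "flatten_first_2_layers": False,    # optional: interleave codec layers 0 and 1
        "model": {"phone_vocab_size": 300, "target_vocab_size": 1024},
        "train": {
            "batch_size": 8,
            "max_tokens": 8000,             # only used with dynamic batching
            "max_sentences": 64,
            "dataset": {"name": "mls", "use_dynamic_batchsize": True},
            "dataloader": {"num_worker": 4, "pin_memory": True, "persistent_workers": True},
            "scheduler": {"warmup_steps": 1000, "total_steps": 100000},
        },
    }
)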
models/tts/valle_v2.1/valle_collator.py ADDED
@@ -0,0 +1,57 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+
10
+ class VALLECollator:
11
+ def __init__(self, cfg=None):
12
+ self.cfg = cfg
13
+
14
+ def __call__(self, batch):
15
+ """Returns: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
16
+ speech: [B, T]
17
+ speech_len: [B]
18
+ phone_ids: [B, T]
19
+ phone_lens: [B]
20
+ """
21
+ assert len(batch) != 0, "batch is empty before None checking"
22
+ batch = [b for b in batch if b is not None]
23
+ assert len(batch) != 0, "batch is empty after None checking"
24
+ packed_batch_features = {}
25
+
26
+ # Function to handle tensor copying
27
+ def process_tensor(data, dtype=torch.float32):
28
+ if isinstance(data, torch.Tensor):
29
+ return data.detach()
30
+ else:
31
+ return torch.tensor(data, dtype=dtype)
32
+
33
+ # Process 'speech' data
34
+ speeches = [process_tensor(b["speech"]) for b in batch]
35
+ packed_batch_features["speech_len"] = torch.tensor(
36
+ [len(s) for s in speeches], dtype=torch.long
37
+ )
38
+ packed_batch_features["speech"] = pad_sequence(
39
+ speeches, batch_first=True, padding_value=0
40
+ )
41
+
42
+ # right-padding 'phone' data
43
+ phones = [process_tensor(b["phone"], dtype=torch.long) for b in batch]
44
+ packed_batch_features["phone_lens"] = torch.tensor(
45
+ [len(phone) for phone in phones], dtype=torch.long
46
+ )
47
+ packed_batch_features["phone_ids"] = pad_sequence(
48
+ phones, batch_first=True, padding_value=0
49
+ )
50
+
51
+ # # Process 'phone' data, with left padding
52
+ # phones = [process_tensor(b['phone'], dtype=torch.long).flip(0) for b in batch] # first reverse the whole sequence
53
+ # packed_batch_features['phone_lens'] = torch.tensor([len(phone) for phone in phones], dtype=torch.long)
54
+ # packed_batch_features['phone_ids'] = pad_sequence(phones, batch_first=True, padding_value=0) # do the right padding
55
+ # packed_batch_features['phone_ids'] = packed_batch_features['phone_ids'].flip(1) # flip back to original order (left padding)
56
+
57
+ return packed_batch_features
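For reference, this is what the collator's padding does on a toy two-item batch; the sketch below is illustrative only and reproduces the same pad_sequence calls rather than importing the class.

import torch
from torch.nn.utils.rnn import pad_sequence

batch = [
    {"speech": torch.randn(16000), "phone": [4, 8, 15]},
    {"speech": torch.randn(24000), "phone": [16, 23, 42, 7, 9]},
]
speeches = [b["speech"] for b in batch]
phones = [torch.tensor(b["phone"], dtype=torch.long) for b in batch]

speech = pad_sequence(speeches, batch_first=True, padding_value=0)    # [2, 24000]
phone_ids = pad_sequence(phones, batch_first=True, padding_value=0)   # [2, 5]
speech_len = torch.tensor([len(s) for s in speeches])                 # tensor([16000, 24000])
phone_lens = torch.tensor([len(p) for p in phones])                   # tensor([3, 5])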
models/tts/valle_v2.1/valle_inference.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torchaudio
8
+
9
+
10
+ class ValleInference(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ use_vocos=False,
14
+ use_speechtokenizer=True,
15
+ ar_path=None,
16
+ nar_path=None,
17
+ speechtokenizer_path=None,
18
+ device="cuda",
19
+ ):
20
+ super().__init__()
21
+
22
+ self.device = device
23
+
24
+ # prepare pretrained VALLE AR model
25
+ from .valle_ar import ValleAR
26
+
27
+ self.ar_model = ValleAR(
28
+ phone_vocab_size=300,
29
+ target_vocab_size=1024,
30
+ pad_token_id=1324,
31
+ bos_target_id=1325,
32
+ eos_target_id=1326,
33
+ bos_phone_id=1327,
34
+ eos_phone_id=1328,
35
+ bos_prompt_id=1329,
36
+ eos_prompt_id=1330,
37
+ num_hidden_layers=16,
38
+ )
39
+ # change the following path to your trained model path
40
+ assert ar_path is not None
41
+ self.ar_model.load_state_dict(torch.load(ar_path, map_location="cpu"))
42
+ self.ar_model.eval().to(self.device)
43
+
44
+ # prepare pretrained VALLE NAR model
45
+ from .valle_nar import ValleNAR
46
+
47
+ self.nar_model = ValleNAR(
48
+ phone_vocab_size=300,
49
+ target_vocab_size=1024,
50
+ pad_token_id=1324,
51
+ bos_target_id=1325,
52
+ eos_target_id=1326,
53
+ bos_phone_id=1327,
54
+ eos_phone_id=1328,
55
+ bos_prompt_id=1329,
56
+ eos_prompt_id=1330,
57
+ num_hidden_layers=16,
58
+ )
59
+ assert nar_path is not None
60
+ self.nar_model.load_state_dict(torch.load(nar_path, map_location="cpu"))
61
+ self.nar_model.eval().to(self.device)
62
+
63
+ # prepare codec encoder
64
+ assert not (
65
+ use_speechtokenizer and use_vocos
66
+ ), "Only one of use_speechtokenizer and use_vocos can be True"
67
+ self.use_speechtokenizer = use_speechtokenizer
68
+ if use_speechtokenizer:
69
+ from models.codec.speechtokenizer.model import SpeechTokenizer
70
+
71
+ # download from https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg
72
+ config_path = speechtokenizer_path + "/config.json"
73
+ ckpt_path = speechtokenizer_path + "/SpeechTokenizer.pt"
74
+ self.codec_encoder = SpeechTokenizer.load_from_checkpoint(
75
+ config_path, ckpt_path
76
+ )
77
+ self.codec_encoder.eval()
78
+ self.codec_encoder.to(device)
79
+ print(f"Loaded SpeechTokenizer from {config_path} and {ckpt_path}")
80
+ else:
81
+ # use Encodec
82
+ from encodec import EncodecModel
83
+
84
+ self.codec_encoder = EncodecModel.encodec_model_24khz()
85
+ self.codec_encoder.set_target_bandwidth(6.0)
86
+ self.codec_encoder.to(self.device)
87
+ if use_vocos:
88
+ from vocos import Vocos
89
+
90
+ self.codec_decoder = Vocos.from_pretrained(
91
+ "charactr/vocos-encodec-24khz"
92
+ )
93
+ self.codec_decoder.to(self.device)
94
+ print("Loaded Vocos")
95
+ print("Loaded EncodecModel")
96
+
97
+ self.use_vocos = use_vocos
98
+
99
+ def decode(self, vq_ids):
100
+ """vq_ids.shape: [8, B, T],
101
+ returns: [B, 1, T]"""
102
+ if self.use_speechtokenizer:
103
+ # infer speechtokenizer
104
+ return self.codec_encoder.decode(vq_ids) # [B, 1, T]
105
+ else:
106
+ if not self.use_vocos:
107
+ # plain EnCodec decoder
108
+ return self.codec_encoder.decode([(vq_ids.transpose(0, 1), None)])
109
+ else:
110
+ # Vocos decoder
111
+ features = self.codec_decoder.codes_to_features(vq_ids.squeeze(1))
112
+ bandwidth_id = torch.tensor([2], device=vq_ids.device)
113
+ return self.codec_decoder.decode(
114
+ features, bandwidth_id=bandwidth_id
115
+ ).unsqueeze(0)
116
+
117
+ def forward(self, batch, chunk_configs: list, return_prompt=False, prompt_len=None):
118
+ """batch: dict(
119
+ speech: [B, T]
120
+ phone_ids: [B, T]
121
+ )
122
+ returns: [B, 1, T] audio
123
+ """
124
+ if prompt_len is None:
125
+ prompt_len = 100000 # no prompt length limiting
126
+ for k, v in batch.items():
127
+ if isinstance(v, torch.Tensor):
128
+ batch[k] = v.to(self.device)
129
+ with torch.no_grad():
130
+ if self.use_speechtokenizer:
131
+ vq_id = self.codec_encoder.encode(
132
+ batch["speech"].unsqueeze(1)
133
+ ) # [B,1,T] -> (n_q, B, T)
134
+ else:
135
+ vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
136
+ vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
137
+ 0, 1
138
+ )
139
+
140
+ # typically we only require one config in the chunk,
141
+ # but we can also use multiple configs to, for example, use different sampling temperature at different positions
142
+ for chunk in chunk_configs:
143
+ ar_vq_ids = self.ar_model.sample_hf(
144
+ batch["phone_ids"],
145
+ vq_id[0, :, :prompt_len],
146
+ top_p=chunk["top_p"],
147
+ top_k=chunk["top_k"],
148
+ temperature=chunk["temperature"],
149
+ num_beams=chunk["num_beams"],
150
+ repeat_penalty=chunk["repeat_penalty"],
151
+ max_length=chunk["max_length"],
152
+ )
153
+ # recovered_audio_ar = self.decode(ar_vq_ids.unsqueeze(0))
154
+ # torchaudio.save('recovered_audio_ar.wav', recovered_audio_ar[0].cpu(), 24000)
155
+
156
+ nar_vq_ids = self.nar_model.sample_hf(
157
+ phone_ids=batch["phone_ids"],
158
+ prompt_ids=vq_id[:, :, :prompt_len],
159
+ first_stage_ids=ar_vq_ids,
160
+ # first_stage_ids=vq_id[0, :, prompt_len:],
161
+ )
162
+
163
+ if return_prompt:
164
+ nar_vq_ids = torch.cat(
165
+ [vq_id[..., :prompt_len], nar_vq_ids], dim=-1
166
+ )
167
+
168
+ recovered_audio = self.decode(nar_vq_ids)
169
+ return recovered_audio # [B, 1, T]
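A hedged sketch of driving ValleInference.forward. The checkpoint paths, the phone tokenization, and the sampling values are placeholders, so the model construction is left commented out; the chunk_configs keys match the ones read in the sampling loop above.

import torch
import torchaudio

# model = ValleInference(
#     ar_path="ckpts/valle_ar.bin",                              # placeholder paths
#     nar_path="ckpts/valle_nar.bin",
#     speechtokenizer_path="ckpts/speechtokenizer_hubert_avg",
# )

batch = {
    "speech": torch.randn(1, 16000 * 3),           # ~3 s enrollment prompt waveform
    "phone_ids": torch.randint(0, 300, (1, 50)),   # prompt + target phones, already tokenized
}
chunk_configs = [
    dict(top_p=0.9, top_k=40, temperature=0.95, num_beams=1,
         repeat_penalty=1.0, max_length=2000)
]
# audio = model(batch, chunk_configs, prompt_len=150)            # [B, 1, T]
# torchaudio.save("out.wav", audio[0].cpu(), 16000)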
models/tts/valle_v2.1/valle_nar.py ADDED
@@ -0,0 +1,801 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import os
11
+ import torch.nn as nn
12
+ from typing import List, Optional, Tuple, Union
13
+
14
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
15
+
16
+ NUM_QUANTIZERS = 8 # number of quantizers in total, currently assumes first layer AR.
17
+ START_QUANTIZATION_LAYER = 1 # start quantization layer
18
+ END_QUANTIZATION_LAYER = 7 # end quantization layer
19
+
20
+
21
+ class LlamaAdaptiveRMSNorm(nn.Module):
22
+ def __init__(self, hidden_size=1024, eps=1e-9, dim_cond=1024):
23
+ super().__init__()
24
+ self.to_weight = nn.Linear(dim_cond, hidden_size)
25
+ nn.init.normal_(self.to_weight.weight, mean=0.0, std=0.02)
26
+ # nn.init.zeros_(self.to_weight.weight)
27
+ # nn.init.ones_(self.to_weight.bias)
28
+ self.variance_epsilon = eps
29
+ self._is_hf_initialized = True # disable automatic init
30
+
31
+ def forward(self, hidden_states, cond_embedding):
32
+ input_dtype = hidden_states.dtype
33
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
34
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
35
+
36
+ weight = self.to_weight(cond_embedding)
37
+
38
+ return (weight * hidden_states).to(input_dtype)
39
+
40
+
41
+ class LlamaNARDecoderLayer(LlamaDecoderLayer):
42
+ def __init__(self, config: LlamaConfig):
43
+ """Override to adaptive layer norm"""
44
+ super().__init__(config=config, layer_idx=0) # init attention, mlp, etc.
45
+ self.input_layernorm = LlamaAdaptiveRMSNorm(
46
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
47
+ )
48
+ self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
49
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
50
+ )
51
+
52
+ # add `cond` in forward function
53
+ def forward(
54
+ self,
55
+ hidden_states: torch.Tensor,
56
+ cond_embedding: torch.Tensor,
57
+ attention_mask: Optional[torch.Tensor] = None,
58
+ position_ids: Optional[torch.LongTensor] = None,
59
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
60
+ output_attentions: Optional[bool] = False,
61
+ use_cache: Optional[bool] = False,
62
+ ) -> Tuple[
63
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
64
+ ]:
65
+ """
66
+ Args:
67
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
68
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
69
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
70
+ output_attentions (`bool`, *optional*):
71
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
72
+ returned tensors for more detail.
73
+ use_cache (`bool`, *optional*):
74
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
75
+ (see `past_key_values`).
76
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
77
+ """
78
+
79
+ residual = hidden_states
80
+
81
+ hidden_states = self.input_layernorm(
82
+ hidden_states, cond_embedding=cond_embedding
83
+ )
84
+
85
+ # Self Attention
86
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
87
+ hidden_states=hidden_states,
88
+ attention_mask=attention_mask,
89
+ position_ids=position_ids,
90
+ past_key_value=past_key_value,
91
+ output_attentions=output_attentions,
92
+ use_cache=use_cache,
93
+ )
94
+ hidden_states = residual + hidden_states
95
+
96
+ # Fully Connected
97
+ residual = hidden_states
98
+ hidden_states = self.post_attention_layernorm(
99
+ hidden_states, cond_embedding=cond_embedding
100
+ )
101
+ hidden_states = self.mlp(hidden_states)
102
+ hidden_states = residual + hidden_states
103
+
104
+ outputs = (hidden_states,)
105
+
106
+ if output_attentions:
107
+ outputs += (self_attn_weights,)
108
+
109
+ if use_cache:
110
+ outputs += (present_key_value,)
111
+
112
+ return outputs
113
+
114
+
115
+ from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
116
+
117
+
118
+ class MultiEmbedding(nn.Module):
119
+ """Embedding for multiple quantization layers, summing up the embeddings of each layer."""
120
+
121
+ def __init__(
122
+ self,
123
+ num_embeddings=1034,
124
+ embedding_dim=1024,
125
+ num_quantization_layers=NUM_QUANTIZERS,
126
+ ):
127
+ super().__init__()
128
+ self.embeddings = nn.ModuleList(
129
+ [
130
+ nn.Embedding(num_embeddings, embedding_dim)
131
+ for _ in range(num_quantization_layers)
132
+ ]
133
+ )
134
+
135
+ # initialize embeddings
136
+ for i in range(num_quantization_layers):
137
+ self.embeddings[i].weight.data.normal_(mean=0.0, std=0.02)
138
+ self._is_hf_initialized = True # disable automatic init
139
+
140
+ def forward(self, input_ids):
141
+ """Input: [num_quant, B, T] -> Output: [B, T, H]"""
142
+ num_quant, B, T = input_ids.shape
143
+ summed_embeddings = torch.zeros(
144
+ B, T, self.embeddings[0].embedding_dim, device=input_ids.device
145
+ )
146
+ for i in range(num_quant):
147
+ summed_embeddings += self.embeddings[i](input_ids[i])
148
+ return summed_embeddings
149
+
150
+
151
+ class LlamaNARModel(LlamaModel):
152
+ def __init__(self, config):
153
+ """Adding adaptive layer norm, conditional embeddings, and multi-level input embeddings to the decoder layer"""
154
+ super().__init__(config)
155
+ self.layers = nn.ModuleList(
156
+ [LlamaNARDecoderLayer(config) for _ in range(config.num_hidden_layers)]
157
+ )
158
+ self.norm = LlamaAdaptiveRMSNorm(
159
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
160
+ )
161
+
162
+ self.embed_cond = nn.Embedding(
163
+ NUM_QUANTIZERS, config.hidden_size
164
+ ) # embedding over the quantization-layer index (layers 1-7 are predicted)
165
+
166
+ for layer in self.layers:
167
+ layer.input_layernorm = LlamaAdaptiveRMSNorm(
168
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
169
+ )
170
+ layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
171
+ config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
172
+ )
173
+
174
+ self.post_init()
175
+
176
+ def _prepare_decoder_attention_mask(
177
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
178
+ ):
179
+ # create noncausal mask
180
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
181
+ combined_attention_mask = None
182
+
183
+ def _expand_mask(
184
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
185
+ ):
186
+ """
187
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
188
+ """
189
+ bsz, src_len = mask.size()
190
+ tgt_len = tgt_len if tgt_len is not None else src_len
191
+
192
+ expanded_mask = (
193
+ mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
194
+ )
195
+
196
+ inverted_mask = 1.0 - expanded_mask
197
+
198
+ return inverted_mask.masked_fill(
199
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
200
+ )
201
+
202
+ if attention_mask is not None:
203
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
204
+ expanded_attn_mask = _expand_mask(
205
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
206
+ ).to(inputs_embeds.device)
207
+ combined_attention_mask = (
208
+ expanded_attn_mask
209
+ if combined_attention_mask is None
210
+ else expanded_attn_mask + combined_attention_mask
211
+ )
212
+
213
+ return combined_attention_mask
214
+
215
+ def forward(
216
+ self,
217
+ input_ids: torch.LongTensor = None, # [num_quant, B, T]
218
+ cond: torch.LongTensor = None, # index for conditional embeddings, [B]
219
+ attention_mask: Optional[torch.Tensor] = None,
220
+ position_ids: Optional[torch.LongTensor] = None,
221
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
222
+ inputs_embeds: Optional[torch.FloatTensor] = None,
223
+ use_cache: Optional[bool] = None,
224
+ output_attentions: Optional[bool] = None,
225
+ output_hidden_states: Optional[bool] = None,
226
+ return_dict: Optional[bool] = None,
227
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
228
+
229
+ # retrieve some shape info
230
+ batch_size, seq_length, _ = input_ids.shape
231
+
232
+ inputs_embeds = input_ids # [B, T, H]
233
+ # embed cond
234
+ cond_embedding = self.embed_cond(cond) # [B, H]
235
+
236
+ output_attentions = (
237
+ output_attentions
238
+ if output_attentions is not None
239
+ else self.config.output_attentions
240
+ )
241
+ output_hidden_states = (
242
+ output_hidden_states
243
+ if output_hidden_states is not None
244
+ else self.config.output_hidden_states
245
+ )
246
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
247
+
248
+ return_dict = (
249
+ return_dict if return_dict is not None else self.config.use_return_dict
250
+ )
251
+
252
+ seq_length_with_past = seq_length
253
+ past_key_values_length = 0
254
+
255
+ if past_key_values is not None:
256
+ past_key_values_length = past_key_values[0][0].shape[2]
257
+ seq_length_with_past = seq_length_with_past + past_key_values_length
258
+
259
+ if position_ids is None:
260
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
261
+ position_ids = torch.arange(
262
+ past_key_values_length,
263
+ seq_length + past_key_values_length,
264
+ dtype=torch.long,
265
+ device=device,
266
+ )
267
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
268
+ else:
269
+ position_ids = position_ids.view(-1, seq_length).long()
270
+
271
+ # embed positions
272
+ if attention_mask is None:
273
+ attention_mask = torch.ones(
274
+ (batch_size, seq_length_with_past),
275
+ dtype=torch.bool,
276
+ device=inputs_embeds.device,
277
+ )
278
+ attention_mask = self._prepare_decoder_attention_mask(
279
+ attention_mask,
280
+ (batch_size, seq_length),
281
+ inputs_embeds,
282
+ past_key_values_length,
283
+ )
284
+
285
+ hidden_states = inputs_embeds
286
+
287
+ if self.gradient_checkpointing and self.training:
288
+ if use_cache:
289
+ use_cache = False
290
+
291
+ # decoder layers
292
+ all_hidden_states = () if output_hidden_states else None
293
+ all_self_attns = () if output_attentions else None
294
+ next_decoder_cache = () if use_cache else None
295
+
296
+ for idx, decoder_layer in enumerate(self.layers):
297
+ if output_hidden_states:
298
+ all_hidden_states += (hidden_states,)
299
+
300
+ past_key_value = (
301
+ past_key_values[idx] if past_key_values is not None else None
302
+ )
303
+
304
+ if self.gradient_checkpointing and self.training:
305
+ raise NotImplementedError
306
+
307
+ def create_custom_forward(module):
308
+ def custom_forward(*inputs):
309
+ # None for past_key_value
310
+ return module(*inputs, output_attentions, None)
311
+
312
+ return custom_forward
313
+
314
+ layer_outputs = torch.utils.checkpoint.checkpoint(
315
+ create_custom_forward(decoder_layer),
316
+ hidden_states,
317
+ attention_mask,
318
+ position_ids,
319
+ None,
320
+ )
321
+ else:
322
+ layer_outputs = decoder_layer(
323
+ hidden_states,
324
+ attention_mask=attention_mask,
325
+ position_ids=position_ids,
326
+ past_key_value=past_key_value,
327
+ output_attentions=output_attentions,
328
+ use_cache=use_cache,
329
+ cond_embedding=cond_embedding, # using cond embed
330
+ )
331
+
332
+ hidden_states = layer_outputs[0]
333
+
334
+ if use_cache:
335
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
336
+
337
+ if output_attentions:
338
+ all_self_attns += (layer_outputs[1],)
339
+
340
+ hidden_states = self.norm(hidden_states, cond_embedding=cond_embedding)
341
+
342
+ # add hidden states from the last decoder layer
343
+ if output_hidden_states:
344
+ all_hidden_states += (hidden_states,)
345
+
346
+ next_cache = next_decoder_cache if use_cache else None
347
+ if not return_dict:
348
+ return tuple(
349
+ v
350
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
351
+ if v is not None
352
+ )
353
+ return BaseModelOutputWithPast(
354
+ last_hidden_state=hidden_states,
355
+ past_key_values=next_cache,
356
+ hidden_states=all_hidden_states,
357
+ attentions=all_self_attns,
358
+ )
359
+
360
+
361
+ from transformers.models.llama.modeling_llama import LlamaPreTrainedModel
362
+ from transformers.models.llama.modeling_llama import CrossEntropyLoss
363
+ from easydict import EasyDict as edict
364
+
365
+
366
+ class LlamaForNARModeling(LlamaPreTrainedModel):
367
+ def __init__(self, config):
368
+ super().__init__(config)
369
+ self.model = LlamaNARModel(config)
370
+
371
+ self.lm_head = nn.ModuleList(
372
+ [
373
+ nn.Linear(config.hidden_size, config.vocab_size, bias=False)
374
+ for i in range(END_QUANTIZATION_LAYER - START_QUANTIZATION_LAYER + 1)
375
+ ]
376
+ )
377
+
378
+ # Initialize weights and apply final processing
379
+ self.post_init()
380
+
381
+ def forward(
382
+ self,
383
+ cond: torch.LongTensor, # added
384
+ prediction_target: torch.LongTensor = None, # added. No shifting. -100 means no loss
385
+ input_ids: torch.LongTensor = None, # expect an embedding, [B, T, H]
386
+ attention_mask: Optional[torch.Tensor] = None,
387
+ position_ids: Optional[torch.LongTensor] = None,
388
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
389
+ inputs_embeds: Optional[torch.FloatTensor] = None,
390
+ # labels: Optional[torch.LongTensor] = None,
391
+ use_cache: Optional[bool] = None,
392
+ output_attentions: Optional[bool] = None,
393
+ output_hidden_states: Optional[bool] = None,
394
+ return_dict: Optional[bool] = None,
395
+ ):
396
+ """Prediction target: [B, T]"""
397
+ output_attentions = (
398
+ output_attentions
399
+ if output_attentions is not None
400
+ else self.config.output_attentions
401
+ )
402
+ output_hidden_states = (
403
+ output_hidden_states
404
+ if output_hidden_states is not None
405
+ else self.config.output_hidden_states
406
+ )
407
+ return_dict = (
408
+ return_dict if return_dict is not None else self.config.use_return_dict
409
+ )
410
+
411
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
412
+ outputs = self.model(
413
+ cond=cond, # added
414
+ input_ids=input_ids,
415
+ attention_mask=attention_mask,
416
+ position_ids=position_ids,
417
+ past_key_values=past_key_values,
418
+ inputs_embeds=inputs_embeds,
419
+ use_cache=use_cache,
420
+ output_attentions=output_attentions,
421
+ output_hidden_states=output_hidden_states,
422
+ return_dict=return_dict,
423
+ )
424
+
425
+ hidden_states = outputs[0]
426
+ logits = self.lm_head[cond - START_QUANTIZATION_LAYER](hidden_states)
427
+
428
+ loss = None
429
+ loss_fct = CrossEntropyLoss()
430
+
431
+ if prediction_target is not None:
432
+ # calculate loss if prediction_target is provided
433
+ logits_tmp = logits.view(-1, logits.size(-1))
434
+ prediction_target = prediction_target.view(-1)
435
+ loss = loss_fct(logits_tmp, prediction_target)
436
+
437
+ return edict(
438
+ loss=loss,
439
+ logits=logits,
440
+ )
441
+
442
+
443
+ class ValleNAR(nn.Module):
444
+ def __init__(
445
+ self,
446
+ phone_vocab_size=256,
447
+ target_vocab_size=1024,
448
+ hidden_size=1024,
449
+ intermediate_size=4096,
450
+ num_hidden_layers=12,
451
+ num_attention_heads=16,
452
+ pad_token_id=1024 + 256,
453
+ bos_target_id=1282,
454
+ eos_target_id=1283,
455
+ bos_phone_id=1284,
456
+ eos_phone_id=1285,
457
+ bos_prompt_id=1286,
458
+ eos_prompt_id=1287,
459
+ use_input_embeds=False,
460
+ emb_dim=256,
461
+ ):
462
+ super(ValleNAR, self).__init__()
463
+ self.config = LlamaConfig(
464
+ vocab_size=phone_vocab_size + target_vocab_size + 10,
465
+ hidden_size=hidden_size,
466
+ intermediate_size=intermediate_size,
467
+ num_hidden_layers=num_hidden_layers,
468
+ num_attention_heads=num_attention_heads,
469
+ pad_token_id=pad_token_id,
470
+ bos_token_id=bos_target_id,
471
+ eos_token_id=eos_target_id,
472
+ use_cache=False,
473
+ )
474
+ self.phone_vocab_size = phone_vocab_size
475
+ self.target_vocab_size = target_vocab_size
476
+ self.pad_token_id = pad_token_id
477
+ self.bos_target_id = bos_target_id
478
+ self.eos_target_id = eos_target_id
479
+ self.bos_phone_id = bos_phone_id
480
+ self.eos_phone_id = eos_phone_id
481
+ self.bos_prompt_id = bos_prompt_id
482
+ self.eos_prompt_id = eos_prompt_id
483
+ self.model = LlamaForNARModeling(self.config)
484
+
485
+ self.use_input_embeds = use_input_embeds
486
+
487
+ self.phone_embedder = nn.Embedding(
488
+ self.phone_vocab_size + 10, hidden_size
489
+ ) # use phone_embedder to embed all eos, bos tokens
490
+ self.prompt_embedder = MultiEmbedding(
491
+ num_embeddings=self.target_vocab_size,
492
+ embedding_dim=hidden_size,
493
+ num_quantization_layers=NUM_QUANTIZERS,
494
+ )
495
+ self.phone_embedder.weight.data.normal_(mean=0.0, std=0.02)
496
+
497
+ # use linear mask schedule when training
498
+ # another option is uniform
499
+ self.mask_layer_schedule = "uniform"
500
+
501
+ # no input embedding is used to provide speaker information
502
+ if self.use_input_embeds:
503
+ self.emb_linear = nn.Linear(emb_dim, hidden_size)
504
+ self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
505
+ self.emb_linear.bias.data.zero_()
506
+
507
+ def forward(
508
+ self,
509
+ phone_ids,
510
+ phone_mask,
511
+ target_ids,
512
+ target_mask,
513
+ target_quantization_layer=None,
514
+ prompt_len=None,
515
+ dropout=0.0,
516
+ ):
517
+ """
518
+ phone_ids: [B, T]
519
+ phone_mask: [B, T]
520
+ target_ids: [8,B,T]
521
+ target_mask: [B, T]
522
+ dropout: rate of dropping out the target tokens
523
+ """
524
+ assert (target_ids < 1024).all(), "target_ids should be less than 1024"
525
+ phone_ids = phone_ids + self.target_vocab_size
526
+ phone_ids = phone_ids * phone_mask + (1 - phone_mask) * self.pad_token_id
527
+ # assert (phone_ids >= 1024).all(), "phone_ids should be greater than 1024"
528
+ # phone_ids, phone_mask, phone_label = self.add_phone_eos_bos_label(
529
+ # phone_ids,
530
+ # phone_mask,
531
+ # self.eos_phone_id,
532
+ # self.bos_phone_id,
533
+ # self.pad_token_id,
534
+ # )
535
+ phone_label = -100 * (1 - phone_mask)
536
+ # get phone embedding
537
+ phone_embedding = self.phone_embedder(
538
+ phone_ids - self.target_vocab_size
539
+ ) # [B, T, H]
540
+
541
+ if prompt_len is not None:
542
+ assert not self.training # inference stage fix prompt len to input
543
+ NUM_PROMPT_TOKENS = prompt_len
544
+ else:
545
+ assert self.training
546
+ # randomly select a prompt length
547
+ assert self.training # randomize prompt len in training
548
+ NUM_PROMPT_TOKENS = np.random.randint(
549
+ min(target_ids.shape[-1] // 4, 5), target_ids.shape[-1] // 2
550
+ )
551
+
552
+ # extract 8-level prompts
553
+ prompt_tokens = target_ids[:, :, :NUM_PROMPT_TOKENS] # [Q, B, T]
554
+ prompt_mask = torch.ones_like(prompt_tokens[0])
555
+ prompt_label = -100 * prompt_mask
556
+ # get prompt embedding
557
+ prompt_embedding = self.prompt_embedder(prompt_tokens) # [B, T, H]
558
+
559
+ # randomly select a target qnt layer to predict
560
+ # total quant layer is 0 to 7
561
+ if target_quantization_layer is None:
562
+ if self.mask_layer_schedule == "linear":
563
+ weights = torch.tensor(
564
+ [
565
+ NUM_QUANTIZERS - i
566
+ for i in range(
567
+ START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1
568
+ )
569
+ ]
570
+ )
571
+ weights = weights / weights.sum()
572
+ mask_layer = (
573
+ torch.multinomial(weights, 1, replacement=True)
574
+ + START_QUANTIZATION_LAYER
575
+ )
576
+ assert (
577
+ mask_layer >= START_QUANTIZATION_LAYER
578
+ and mask_layer <= END_QUANTIZATION_LAYER
579
+ )
580
+ target_quantization_layer = mask_layer.item()
581
+ elif self.mask_layer_schedule == "cosine":
582
+ weights = torch.tensor(
583
+ [
584
+ np.cos(i / NUM_QUANTIZERS * np.pi / 2)
585
+ for i in range(
586
+ START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1
587
+ )
588
+ ]
589
+ )
590
+ weights = weights / weights.sum()
591
+ mask_layer = (
592
+ torch.multinomial(weights, 1, replacement=True)
593
+ + START_QUANTIZATION_LAYER
594
+ )
595
+ assert (
596
+ mask_layer >= START_QUANTIZATION_LAYER
597
+ and mask_layer <= END_QUANTIZATION_LAYER
598
+ )
599
+ target_quantization_layer = mask_layer.item()
600
+ breakpoint()
601
+ elif self.mask_layer_schedule == "uniform":
602
+ target_quantization_layer = np.random.randint(
603
+ START_QUANTIZATION_LAYER, END_QUANTIZATION_LAYER + 1
604
+ )
605
+
606
+ # print(f'target layer: {target_quantization_layer}')
607
+ # prompt of the target part
608
+ target_prompt_ids = target_ids[
609
+ :target_quantization_layer, :, NUM_PROMPT_TOKENS:
610
+ ]
611
+
612
+ def randomly_set_elements(tensor, fraction, value):
613
+ """
614
+ Randomly set a fraction of the elements in a tensor to a specific value.
615
+
616
+ Args:
617
+ tensor (torch.Tensor): The input tensor.
618
+ fraction (float): The fraction of elements to set to the specified value (between 0 and 1).
619
+ value (float or int): The value to set the elements to.
620
+
621
+ Returns:
622
+ torch.Tensor: The tensor with some elements set to the specified value.
623
+ """
624
+ # Create a mask with the same shape as the tensor
625
+ mask = torch.rand_like(tensor, dtype=torch.float32) < fraction
626
+ # Clone the tensor to avoid modifying the original tensor
627
+ result_tensor = tensor.clone()
628
+ # Set the elements where the mask is True to the specified value
629
+ result_tensor[mask] = value
630
+ return result_tensor
631
+
632
+ if dropout != 0.0:
633
+ target_prompt_ids = randomly_set_elements(
634
+ target_prompt_ids, dropout, self.target_vocab_size
635
+ )
636
+
637
+ target_embedding = self.prompt_embedder(target_prompt_ids)
638
+
639
+ # mask of the target part
640
+ target_mask = target_mask[:, NUM_PROMPT_TOKENS:]
641
+
642
+ target_labels = target_ids[
643
+ target_quantization_layer, :, NUM_PROMPT_TOKENS:
644
+ ] * target_mask + (-100 * (1 - target_mask))
645
+
646
+ # input embeddings
647
+ input_embeddings = torch.cat(
648
+ [phone_embedding, prompt_embedding, target_embedding], dim=1
649
+ )
650
+ input_mask = torch.cat([phone_mask, prompt_mask, target_mask], dim=1) # [B, T]
651
+ prediction_target = torch.cat(
652
+ [phone_label, prompt_label, target_labels], dim=1
653
+ ) # [B, T]
654
+
655
+ out = self.model(
656
+ cond=torch.tensor(
657
+ target_quantization_layer,
658
+ device=prediction_target.device,
659
+ dtype=torch.long,
660
+ ),
661
+ input_ids=input_embeddings,
662
+ prediction_target=prediction_target,
663
+ attention_mask=input_mask,
664
+ return_dict=True,
665
+ )
666
+ logits = out.logits[:, -target_embedding.shape[1] :, :]
667
+ targets = prediction_target[..., -target_embedding.shape[1] :]
668
+ top1_acc = logits.argmax(-1) == targets
669
+ top1_acc = (top1_acc * target_mask).sum() / target_mask.sum()
670
+
671
+ top5_acc = (logits.topk(5, dim=-1).indices == targets.unsqueeze(-1)).any(-1)
672
+ top5_acc = (top5_acc * target_mask).sum() / target_mask.sum()
673
+
674
+ top10_acc = (logits.topk(10, dim=-1).indices == targets.unsqueeze(-1)).any(-1)
675
+ top10_acc = (top10_acc * target_mask).sum() / target_mask.sum()
676
+
677
+ out.target_quantization_layer = target_quantization_layer
678
+ out.top1_acc = top1_acc
679
+ out.top5_acc = top5_acc
680
+ out.top10_acc = top10_acc
681
+
682
+ return out
683
+
684
+ def add_phone_eos_bos_label(
685
+ self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
686
+ ):
687
+ # phone_ids: [B, T]
688
+ # phone_mask: [B, T]
689
+
690
+ phone_ids = phone_ids + self.target_vocab_size * phone_mask
691
+
692
+ phone_ids = phone_ids * phone_mask
693
+ phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
694
+ 1 - phone_mask, (0, 1), value=1
695
+ ) # make pad token eos token, add eos token at the end
696
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add eos mask
697
+ phone_ids = phone_ids * phone_mask + pad_token_id * (
698
+ 1 - phone_mask
699
+ ) # restore pad token ids
700
+ phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id) # add bos token
701
+ phone_mask = F.pad(phone_mask, (1, 0), value=1) # add bos mask
702
+ phone_label = -100 * torch.ones_like(
703
+ phone_ids
704
+ ) # loss for entire phone is not computed (passed to llama)
705
+ return phone_ids, phone_mask, phone_label
706
+
707
+ @torch.no_grad()
708
+ def sample_hf(
709
+ self,
710
+ phone_ids, # [B, T]
711
+ prompt_ids, # [8, B, T]
712
+ first_stage_ids, # [B, T]
713
+ top_k=50,
714
+ top_p=1,
715
+ temperature=1.1,
716
+ first_stage_ids_gt=None, # [Q, B, T]
717
+ first_stage_ids_gt_end_layer=None, # 2 to 8
718
+ ):
719
+ """
720
+ phone_ids: [B, T]
721
+ prompt_ids: [8, B, T]
722
+ first_stage_ids: [B, T] result from first quant layer. Should be continuation of prompt_ids
723
+ """
724
+ phone_mask = torch.ones_like(phone_ids, dtype=torch.long)
725
+
726
+ assert prompt_ids.shape[-1] >= 5, "prompt_ids should have at least 5 tokens"
727
+ target_ids = torch.cat(
728
+ [prompt_ids, first_stage_ids.expand(prompt_ids.shape[0], -1, -1)], dim=-1
729
+ )
730
+ target_mask = torch.ones_like(target_ids[0], dtype=torch.long)
731
+
732
+ if first_stage_ids_gt is not None:
733
+ target_ids[
734
+ :first_stage_ids_gt_end_layer, :, -first_stage_ids_gt.shape[-1] :
735
+ ] = first_stage_ids_gt[:first_stage_ids_gt_end_layer]
736
+
737
+ gen_len = first_stage_ids.shape[-1]
738
+
739
+ start_qnt_layer = 1
740
+ if first_stage_ids_gt_end_layer is not None:
741
+ start_qnt_layer = first_stage_ids_gt_end_layer
742
+ for qnt_level in range(start_qnt_layer, 8):
743
+ out = self.forward(
744
+ phone_ids=phone_ids,
745
+ phone_mask=phone_mask,
746
+ target_ids=target_ids,
747
+ target_mask=target_mask,
748
+ target_quantization_layer=qnt_level,
749
+ prompt_len=prompt_ids.shape[-1],
750
+ )
751
+ logits = out.logits
752
+ gen_tokens = torch.argmax(logits, dim=-1).reshape(-1)[
753
+ -gen_len:
754
+ ] # [T], generated tokens in this level
755
+
756
+ # overwrite the target_ids with the generated tokens
757
+ target_ids[qnt_level, :, -gen_len:] = gen_tokens
758
+
759
+ return target_ids[:, :, -gen_len:]
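+ # sample_hf generates the residual layers one at a time: for each quantization level
+ # (start_qnt_layer..7) the model runs once in NAR mode, the greedy argmax tokens overwrite
+ # that level of target_ids, and later levels are therefore conditioned on everything
+ # generated so far. Only the last gen_len frames (the continuation) are returned.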
760
+
761
+
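+ # test() is a quick overfitting sanity check: it fits the model on a single random batch
+ # for 200 steps, then verifies that greedy sampling reproduces the memorised last 10 frames.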
762
+ def test():
763
+ model = ValleNAR().cuda()
764
+
765
+ phone_ids = torch.LongTensor([1, 2, 3, 4, 5]).reshape(1, -1).cuda()
766
+ phone_mask = torch.LongTensor([1, 1, 1, 1, 1]).reshape(1, -1).cuda()
767
+ target_ids = torch.randint(high=1024, size=(8, 1, 250), dtype=torch.long).cuda()
768
+ target_mask = torch.ones(1, 250, dtype=torch.long).cuda()
769
+ optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
770
+
771
+ for i in range(200):
772
+ optimizer.zero_grad()
773
+ out = model(
774
+ phone_ids=phone_ids,
775
+ phone_mask=phone_mask,
776
+ target_ids=target_ids,
777
+ target_mask=target_mask,
778
+ # target_quantization_layer=1+i%6,
779
+ )
780
+ loss = out.loss
781
+
782
+ loss.backward()
783
+
784
+ optimizer.step()
785
+
786
+ print(f"iter={i}, {loss}.")
787
+ target_ids_short = target_ids[:, :, :240]
788
+
789
+ model.eval()
790
+ sampled = model.sample_hf(
791
+ phone_ids, prompt_ids=target_ids_short, first_stage_ids=target_ids[0, :, 240:]
792
+ )
793
+
794
+ print(target_ids[:, :, -10:])
795
+ print(sampled)
796
+
797
+ print((sampled == target_ids[:, :, -10:]).all())
798
+
799
+
800
+ if __name__ == "__main__":
801
+ test()
models/tts/valle_v2.1/valle_nar_trainer.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torchaudio
8
+ import numpy as np
9
+ import time
10
+ from .valle_ar_trainer import ValleARTrainer, make_pad_mask
11
+
12
+
13
+ class ValleNARTrainer(ValleARTrainer):
14
+ def __init__(self, args=None, cfg=None):
15
+ super().__init__(args, cfg)
16
+ print("simple NAR")
17
+ self.top1_accuracies = {
18
+ 1: [],
19
+ 2: [],
20
+ 3: [],
21
+ 4: [],
22
+ 5: [],
23
+ 6: [],
24
+ 7: [],
25
+ }
26
+ self.top5_accuracies = {
27
+ 1: [],
28
+ 2: [],
29
+ 3: [],
30
+ 4: [],
31
+ 5: [],
32
+ 6: [],
33
+ 7: [],
34
+ }
35
+ self.top10_accuracies = {
36
+ 1: [],
37
+ 2: [],
38
+ 3: [],
39
+ 4: [],
40
+ 5: [],
41
+ 6: [],
42
+ 7: [],
43
+ }
44
+
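+ # The dicts above are buckets for per-layer top-k accuracy (NAR quantization layers 1-7;
+ # layer 0 comes from the separate AR stage). The code that fills them further down is
+ # currently commented out, and the metrics are logged to the accelerator instead.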
45
+ def _build_model(self):
46
+ from .valle_nar import ValleNAR
47
+
48
+ return ValleNAR(**self.cfg.model)
49
+
50
+ def _train_step(self, batch):
51
+ # encode speech with the codec into discrete tokens
52
+ """batch: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
53
+ speech: [B, T]
54
+ speech_len: [B]
55
+ phone_ids: [B, T]
56
+ phone_lens: [B]
57
+ """
58
+ device = self.accelerator.device
59
+ for k, v in batch.items():
60
+ if isinstance(v, torch.Tensor):
61
+ batch[k] = v.to(device)
62
+
63
+ with torch.no_grad():
64
+ if self.cfg.use_speechtokenizer:
65
+ # Extract discrete codes from SpeechTokenizer
66
+ # 16k
67
+ vq_id = self.codec_encoder.encode(
68
+ batch["speech"].unsqueeze(1)
69
+ ) # [B, 1, T] -> (n_q, B, T)
70
+ # RVQ_1 = codes[:1, :, :] # contains content info and can be treated as semantic tokens
71
+ # RVQ_supplement = codes[1:, :, :] # contains timbre info and completes what the first quantizer loses
72
+ # Concatenate the semantic tokens (RVQ_1) with the supplementary timbre tokens before decoding:
73
+ # wav = self.codec_encoder.decode(vq_id)
74
+ # torchaudio.save('a.wav', wav[0].cpu(), 16000)
75
+
76
+ # # Decoding from RVQ-i:j tokens from the ith quantizers to the jth quantizers
77
+ # wav = model.decode(codes[i: (j + 1)], st=i)
78
+ else:
79
+ # using encodec, 24k
80
+ vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
81
+ vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
82
+ 0, 1
83
+ )
84
+
85
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
86
+ # torchaudio.save('a.wav', recovered_audio[0], 16000)
87
+ # vq_id: [8, B, T//320]
88
+ batch["speech"] = vq_id
89
+ batch["speech_len"] = batch["speech_len"] // 320 # our codec downsamples 320x
90
+ assert batch["speech_len"].max() <= batch["speech"].shape[-1]
91
+
92
+ phone_mask = 1 - make_pad_mask(
93
+ batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False
94
+ ).to(torch.long)
95
+ speech_mask = 1 - make_pad_mask(
96
+ batch["speech_len"], max_len=batch["speech"].size(-1)
97
+ ).to(torch.long)
98
+
99
+ np.random.seed(int(time.time()) - 5 * self.accelerator.process_index)  # decorrelate the numpy RNG across processes
100
+
101
+ if hasattr(self.cfg.train, "dropout"):
102
+ dropout = self.cfg.train.dropout
103
+ else:
104
+ dropout = 0.0
105
+
106
+ out = self.model(
107
+ phone_ids=batch["phone_ids"],
108
+ phone_mask=phone_mask,
109
+ target_ids=batch["speech"],
110
+ target_mask=speech_mask,
111
+ dropout=dropout,
112
+ )
113
+ loss = out.loss
114
+
115
+ self.accelerator.log(
116
+ {f"Train/NAR L{out.target_quantization_layer} Top1 acc": out.top1_acc},
117
+ step=self.step,
118
+ )
119
+ self.accelerator.log(
120
+ {f"Train/NAR L{out.target_quantization_layer} Top5 acc": out.top5_acc},
121
+ step=self.step,
122
+ )
123
+ self.accelerator.log(
124
+ {f"Train/NAR L{out.target_quantization_layer} Top10 acc": out.top10_acc},
125
+ step=self.step,
126
+ )
127
+
128
+ # if hasattr(out, 'top1_acc'):
129
+ # idx = out.target_quantization_layer
130
+ # self.top1_accuracies[idx].append(out.top1_acc)
131
+ # self.top5_accuracies[idx].append(out.top5_acc)
132
+ # self.top10_accuracies[idx].append(out.top10_acc)
133
+ # if len(self.top1_accuracies[idx]) >= 160:
134
+ # breakpoint()
135
+ # if self.accelerator.is_main_process:
136
+ # print(loss)
137
+ return loss
138
+
139
+ def _test_step(self, batch):
140
+ # encode speech with the codec into discrete tokens
141
+ """batch: dict('speech', 'speech_len', 'phone_ids', 'phone_lens')
142
+ speech: [B, T]
143
+ speech_len: [B]
144
+ phone_ids: [B, T]
145
+ phone_lens: [B]
146
+ """
147
+ import torchaudio
148
+
149
+ device = self.accelerator.device
150
+ for k, v in batch.items():
151
+ if isinstance(v, torch.Tensor):
152
+ batch[k] = v.to(device)
153
+ with torch.no_grad():
154
+ if self.cfg.use_speechtokenizer:
155
+ # Extract discrete codes from SpeechTokenizer
156
+ # 16k
157
+ vq_id = self.codec_encoder.encode(
158
+ batch["speech"].unsqueeze(1)
159
+ ) # [B,1,T] -> (n_q, B, T)
160
+ # Concatenating semantic tokens (RVQ_1) and supplementary timbre tokens and then decoding
161
+ # wav = self.codec_encoder.decode(vq_id)
162
+ # torchaudio.save('a.wav', wav[0].cpu(), 16000)
163
+
164
+ else:
165
+ vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
166
+ vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
167
+ 0, 1
168
+ )
169
+ # recovered_audio = self.codec_encoder.decode([(vq_id.transpose(0,1), None)])
170
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
171
+ # torchaudio.save('a.wav', recovered_audio[0], 16000)
172
+ # vq_id: [8, B, T//200]
173
+
174
+ # vq_emb = self.codec_decoder.quantizer.vq2emb(vq=vq_id[:1], n_quantizers=1)
175
+ # recovered_audio = self.codec_decoder(vq_emb, vq=False)
176
+ # recovered_audio.shape: torch.Size([1, 1, 50200])
177
+
178
+ batch["speech"] = vq_id
179
+
180
+ # save gt
181
+ if self.cfg.use_speechtokenizer:
182
+ recovered_audio = self.codec_encoder.decode(vq_id)
183
+ else:
184
+ recovered_audio = self.codec_encoder.decode(
185
+ [(vq_id.transpose(0, 1), None)]
186
+ )
187
+ torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000)
188
+ self.model.eval()
189
+ out_vq_ids = self.model.sample_hf(
190
+ phone_ids=batch["phone_ids"][:1],
191
+ prompt_ids=batch["speech"][:, :1, :150],
192
+ first_stage_ids=batch["speech"][0, :1, 150:],
193
+ )
194
+ # breakpoint()
195
+ # out_vq_ids = torch.cat([batch['speech'][:, :225], out_vq_ids], dim=1)
196
+
197
+ # reconstruct audio from the generated tokens
198
+ if self.cfg.use_speechtokenizer:
199
+ recovered_audio = self.codec_encoder.decode(out_vq_ids)
200
+ else:
201
+ recovered_audio = self.codec_encoder.decode(
202
+ [(out_vq_ids.transpose(0, 1)[:1], None)]
203
+ )
204
+ torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000)
205
+ breakpoint()
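+ # _test_step is a manual listening check: it writes the ground-truth reconstruction to
+ # gt.wav, resynthesises the NAR output to a.wav, and stops at a breakpoint for inspection.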
utils/g2p/g2p.py ADDED
@@ -0,0 +1,321 @@
1
+ import json
2
+ import re
3
+ import os
4
+ from typing import List, Pattern, Union
5
+ from phonemizer.utils import list2str, str2list
6
+ from phonemizer.backend import EspeakBackend
7
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
8
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
9
+ from phonemizer.punctuation import Punctuation
10
+ from phonemizer.separator import Separator
11
+ import jieba
12
+ import cn2an
13
+
14
+ # List of (Latin alphabet, bopomofo) pairs:
15
+ _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
16
+ ('a', 'ㄟˉ'),
17
+ ('b', 'ㄅㄧˋ'),
18
+ ('c', 'ㄙㄧˉ'),
19
+ ('d', 'ㄉㄧˋ'),
20
+ ('e', 'ㄧˋ'),
21
+ ('f', 'ㄝˊㄈㄨˋ'),
22
+ ('g', 'ㄐㄧˋ'),
23
+ ('h', 'ㄝˇㄑㄩˋ'),
24
+ ('i', 'ㄞˋ'),
25
+ ('j', 'ㄐㄟˋ'),
26
+ ('k', 'ㄎㄟˋ'),
27
+ ('l', 'ㄝˊㄛˋ'),
28
+ ('m', 'ㄝˊㄇㄨˋ'),
29
+ ('n', 'ㄣˉ'),
30
+ ('o', 'ㄡˉ'),
31
+ ('p', 'ㄆㄧˉ'),
32
+ ('q', 'ㄎㄧㄡˉ'),
33
+ ('r', 'ㄚˋ'),
34
+ ('s', 'ㄝˊㄙˋ'),
35
+ ('t', 'ㄊㄧˋ'),
36
+ ('u', 'ㄧㄡˉ'),
37
+ ('v', 'ㄨㄧˉ'),
38
+ ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
39
+ ('x', 'ㄝˉㄎㄨˋㄙˋ'),
40
+ ('y', 'ㄨㄞˋ'),
41
+ ('z', 'ㄗㄟˋ')
42
+ ]]
43
+
44
+ # List of (bopomofo, ipa) pairs:
45
+ _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
46
+ ('ㄅㄛ', 'p⁼wo'),
47
+ ('ㄆㄛ', 'pʰwo'),
48
+ ('ㄇㄛ', 'mwo'),
49
+ ('ㄈㄛ', 'fwo'),
50
+ ('ㄧㄢ', '|jɛn'),
51
+ ('ㄩㄢ', '|ɥæn'),
52
+ ('ㄧㄣ', '|in'),
53
+ ('ㄩㄣ', '|ɥn'),
54
+ ('ㄧㄥ', '|iŋ'),
55
+ ('ㄨㄥ', '|ʊŋ'),
56
+ ('ㄩㄥ', '|jʊŋ'),
57
+ # Add
58
+ ('ㄧㄚ', '|ia'),
59
+ ('ㄧㄝ', '|iɛ'),
60
+ ('ㄧㄠ', '|iɑʊ'),
61
+ ('ㄧㄡ', '|ioʊ'),
62
+ ('ㄧㄤ', '|iɑŋ'),
63
+ ('ㄨㄚ', '|ua'),
64
+ ('ㄨㄛ', '|uo'),
65
+ ('ㄨㄞ', '|uaɪ'),
66
+ ('ㄨㄟ', '|ueɪ'),
67
+ ('ㄨㄢ', '|uan'),
68
+ ('ㄨㄣ', '|uən'),
69
+ ('ㄨㄤ', '|uɑŋ'),
70
+ ('ㄩㄝ', '|ɥɛ'),
71
+ # End
72
+ ('ㄅ', 'p⁼'),
73
+ ('ㄆ', 'pʰ'),
74
+ ('ㄇ', 'm'),
75
+ ('ㄈ', 'f'),
76
+ ('ㄉ', 't⁼'),
77
+ ('ㄊ', 'tʰ'),
78
+ ('ㄋ', 'n'),
79
+ ('ㄌ', 'l'),
80
+ ('ㄍ', 'k⁼'),
81
+ ('ㄎ', 'kʰ'),
82
+ ('ㄏ', 'x'),
83
+ ('ㄐ', 'tʃ⁼'),
84
+ ('ㄑ', 'tʃʰ'),
85
+ ('ㄒ', 'ʃ'),
86
+ ('ㄓ', 'ts`⁼'),
87
+ ('ㄔ', 'ts`ʰ'),
88
+ ('ㄕ', 's`'),
89
+ ('ㄖ', 'ɹ`'),
90
+ ('ㄗ', 'ts⁼'),
91
+ ('ㄘ', 'tsʰ'),
92
+ ('ㄙ', '|s'),
93
+ ('ㄚ', '|a'),
94
+ ('ㄛ', '|o'),
95
+ ('ㄜ', '|ə'),
96
+ ('ㄝ', '|ɛ'),
97
+ ('ㄞ', '|aɪ'),
98
+ ('ㄟ', '|eɪ'),
99
+ ('ㄠ', '|ɑʊ'),
100
+ ('ㄡ', '|oʊ'),
101
+ ('ㄢ', '|an'),
102
+ ('ㄣ', '|ən'),
103
+ ('ㄤ', '|ɑŋ'),
104
+ ('ㄥ', '|əŋ'),
105
+ ('ㄦ', 'əɹ'),
106
+ ('ㄧ', '|i'),
107
+ ('ㄨ', '|u'),
108
+ ('ㄩ', '|ɥ'),
109
+ ('ˉ', '→|'),
110
+ ('ˊ', '↑|'),
111
+ ('ˇ', '↓↑|'),
112
+ ('ˋ', '↓|'),
113
+ ('˙', '|'),
114
+ (',', ','),
115
+ ('。', '.'),
116
+ ('!', '!'),
117
+ ('?', '?'),
118
+ ('—', '-'),
119
+ ]]
120
+
121
+ # Convert numbers to Chinese pronunciation
122
+ def number_to_chinese(text):
123
+ numbers = re.findall(r'\d+(?:\.?\d+)?', text)
124
+ for number in numbers:
125
+ text = text.replace(number, cn2an.an2cn(number), 1)
126
+ return text
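+ # e.g. number_to_chinese("我有2只猫") should yield "我有二只猫"; the exact spelling of the
+ # digits follows cn2an's default mode.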
127
+
128
+ # Word segmentation, then convert the Chinese characters to pinyin (bopomofo)
129
+ def chinese_to_bopomofo(text):
130
+ from pypinyin import lazy_pinyin, BOPOMOFO
131
+ text = text.replace('、', ',').replace(';', ',').replace(':', ',')
132
+ text = re.sub(r"\s+", "", text)
133
+ words = jieba.lcut(text, cut_all=False)
134
+ text = ''
135
+ for word in words:
136
+ bopomofos = lazy_pinyin(word, BOPOMOFO)
137
+ if not re.search('[\u4e00-\u9fff]', word):
138
+ text += word
139
+ continue
140
+ for i in range(len(bopomofos)):
141
+ bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
142
+ if text != '':
143
+ text += '|'
144
+ text += '|'.join(bopomofos)
145
+ return text
146
+
147
+ # Convert latin pronunciation to pinyin (bopomofo)
148
+ def latin_to_bopomofo(text):
149
+ for regex, replacement in _latin_to_bopomofo:
150
+ text = re.sub(regex, replacement, text)
151
+ return text
152
+
153
+ # Convert pinyin (bopomofo) to IPA
154
+ def bopomofo_to_ipa(text):
155
+ for regex, replacement in _bopomofo_to_ipa:
156
+ text = re.sub(regex, replacement, text)
157
+ return text
158
+
159
+ def _chinese_to_ipa(text):
160
+ text = number_to_chinese(text.strip())
161
+ text = chinese_to_bopomofo(text)
162
+ text = latin_to_bopomofo(text)
163
+ text = bopomofo_to_ipa(text)
164
+ text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
165
+ r'\1ɹ\2', text)
166
+ text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
167
+ text = re.sub(r'^\||[^\w\s_,\.\?!\|\'→↓↑⁼ʰ`]', '', text)
168
+ text = re.sub(r'([,.!?])', r'|\1', text)
169
+ text = re.sub(r'\|+', '|', text)
170
+ return text
171
+
172
+ # Convert Chinese to IPA
173
+ def chinese_to_ipa(text, text_tokenizer):
174
+ # phonemes = text_tokenizer(text.strip())
175
+ if type(text) == str:
176
+ return _chinese_to_ipa(text)
177
+ else:
178
+ result_ph = []
179
+ for t in text:
180
+ result_ph.append(_chinese_to_ipa(t))
181
+ return result_ph
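+ # Note that the Chinese path never touches the passed-in text_tokenizer: zh text goes
+ # through jieba + pypinyin and the bopomofo/IPA tables at the top of this file, while the
+ # espeak backend is only used for the other languages.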
182
+
183
+
184
+ _special_map = [
185
+ ('t|ɹ', 'tɹ'),
186
+ ('d|ɹ', 'dɹ'),
187
+ ('t|s', 'ts'),
188
+ ('d|z', 'dz'),
189
+ ('ɐ', 'ɚ'),
190
+ ('ᵻ', 'ɪ'),
191
+ ('əl', 'l'),
192
+ ('x', 'k'),
193
+ ('ɬ', 'l'),
194
+ ('ʔ', 't'),
195
+ ('n̩', 'n'),
196
+ ('oː|ɹ', 'oːɹ')
197
+ ]
198
+
199
+ # special map
200
+ def special_map(text):
201
+ for regex, replacement in _special_map:
202
+ regex = regex.replace("|", "\\|")  # escape "|" so it is treated literally in the regex below
203
+ while re.search(r'(^|[_|]){}([_|]|$)'.format(regex), text):
204
+ text = re.sub(r'(^|[_|]){}([_|]|$)'.format(regex), r'\1{}\2'.format(replacement), text)
205
+ text = re.sub(r'([,.!?])', r'|\1', text)
206
+ return text
207
+
208
+ def english_to_ipa(text, text_tokenizer):
209
+ # text = _english_to_ipa(text)
210
+ phonemes = text_tokenizer(text)
211
+ if type(text) == str:
212
+ return special_map(phonemes)
213
+ else:
214
+ result_ph = []
215
+ for phone in phonemes:
216
+ result_ph.append(special_map(phone))
217
+ return result_ph
218
+
219
+ def cjekfd_cleaners(text, language, text_tokenizers):
220
+
221
+ if language == 'zh':
222
+ return chinese_to_ipa(text, text_tokenizers['zh'])
223
+ elif language == 'en':
224
+ return english_to_ipa(text, text_tokenizers['en'])
225
+ else:
226
+ raise Exception('Unknown language: %s' % language)
228
+
229
+ class TextTokenizer:
230
+ """Phonemize Text."""
231
+
232
+ def __init__(
233
+ self,
234
+ language="en-us",
235
+ backend="espeak",
236
+ separator=Separator(word="|_|", syllable="-", phone="|"),
237
+ preserve_punctuation=False,
238
+ punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
239
+ with_stress: bool = False,
240
+ tie: Union[bool, str] = False,
241
+ language_switch: LanguageSwitch = "remove-flags",
242
+ words_mismatch: WordMismatch = "ignore",
243
+ ) -> None:
244
+ self.backend = EspeakBackend(
245
+ language,
246
+ punctuation_marks=punctuation_marks,
247
+ preserve_punctuation=preserve_punctuation,
248
+ with_stress=with_stress,
249
+ tie=tie,
250
+ language_switch=language_switch,
251
+ words_mismatch=words_mismatch,
252
+ )
253
+
254
+ self.separator = separator
255
+
256
+ def __call__(self, text, strip=True) -> List[str]:
257
+
258
+ text_type = type(text)
259
+ text = [re.sub(r'[^\w\s_,\.\?!\|\']', '', line.strip()) for line in str2list(text)]
260
+ phonemized = self.backend.phonemize(
261
+ text, separator=self.separator, strip=strip, njobs=1
262
+ )
263
+ if text_type == str:
264
+ return list2str(phonemized)
265
+ return phonemized
266
+
267
+ class PhonemeBpeTokenizer:
268
+
269
+ def __init__(self, vacab_path="./utils/g2p/mls_en.json"):
270
+ self.lang2backend = {
271
+ 'zh': "cmn",
272
+ 'ja': "ja",
273
+ "en": "en-us",
274
+ "fr": "fr-fr",
275
+ "ko": "ko",
276
+ "de": "de",
277
+ }
278
+ self.text_tokenizers = {}
279
+ self.int_text_tokenizers()
280
+
281
+ with open(vacab_path, 'r') as f:
282
+ json_data = f.read()
283
+ data = json.loads(json_data)
284
+ self.vocab = data['vocab']
285
+
286
+ def int_text_tokenizers(self):
287
+ for key, value in self.lang2backend.items():
288
+ self.text_tokenizers[key] = TextTokenizer(language=value)
289
+
290
+ def tokenize(self, text, language):
291
+
292
+ # 1. convert text to phoneme
293
+ phonemes = self._clean_text(text, language, ['cjekfd_cleaners'])
294
+ # print('clean text: ', phonemes)
295
+
296
+ # 2. tokenize phonemes
297
+ phoneme_tokens = self.phoneme2token(phonemes)
298
+
299
+ return phonemes, phoneme_tokens
300
+
301
+ def _clean_text(self, text, language, cleaner_names):
302
+
303
+ text = cjekfd_cleaners(text, language, self.text_tokenizers)
304
+ return text
305
+
306
+ def phoneme2token(self, phonemes):
307
+ tokens = []
308
+ if isinstance(phonemes, list):
309
+ for phone in phonemes:
310
+ phonemes_split = phone.split("|")
311
+ tokens.append([self.vocab[p] for p in phonemes_split if p in self.vocab])
312
+ else:
313
+ phonemes_split = phonemes.split("|")
314
+ tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab]
315
+ return tokens
316
+
317
+ text_tokenizer = PhonemeBpeTokenizer()
318
+
319
+ def phonemizer_g2p(text, language):
320
+
321
+ return text_tokenizer.tokenize(text=text, language=language)
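+ # Illustrative usage (assumes an espeak-ng backend is installed; the exact phonemes depend
+ # on the backend and voice):
+ #   phonemes, tokens = phonemizer_g2p("hello world", "en")
+ #   # phonemes is the "|"-separated IPA string, tokens are the matching ids from mls_en.json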
utils/g2p/mls_emilia.json ADDED
@@ -0,0 +1,335 @@
1
+ {
2
+ "[UNK]": 0,
3
+ "_": 1,
4
+ "b": 2,
5
+ "d": 3,
6
+ "f": 4,
7
+ "h": 5,
8
+ "i": 6,
9
+ "j": 7,
10
+ "k": 8,
11
+ "l": 9,
12
+ "m": 10,
13
+ "n": 11,
14
+ "p": 12,
15
+ "r": 13,
16
+ "s": 14,
17
+ "t": 15,
18
+ "v": 16,
19
+ "w": 17,
20
+ "x": 18,
21
+ "z": 19,
22
+ "æ": 20,
23
+ "ç": 21,
24
+ "ð": 22,
25
+ "ŋ": 23,
26
+ "ɐ": 24,
27
+ "ɔ": 25,
28
+ "ə": 26,
29
+ "ɚ": 27,
30
+ "ɛ": 28,
31
+ "ɡ": 29,
32
+ "ɪ": 30,
33
+ "ɬ": 31,
34
+ "ɹ": 32,
35
+ "ɾ": 33,
36
+ "ʃ": 34,
37
+ "ʊ": 35,
38
+ "ʌ": 36,
39
+ "ʒ": 37,
40
+ "ʔ": 38,
41
+ "θ": 39,
42
+ "ᵻ": 40,
43
+ "aɪ": 41,
44
+ "aʊ": 42,
45
+ "dʒ": 43,
46
+ "eɪ": 44,
47
+ "iə": 45,
48
+ "iː": 46,
49
+ "n̩": 47,
50
+ "oʊ": 48,
51
+ "oː": 49,
52
+ "tʃ": 50,
53
+ "uː": 51,
54
+ "ææ": 52,
55
+ "ɐɐ": 53,
56
+ "ɑː": 54,
57
+ "ɑ̃": 55,
58
+ "ɔɪ": 56,
59
+ "ɔː": 57,
60
+ "ɔ̃": 58,
61
+ "əl": 59,
62
+ "ɛɹ": 60,
63
+ "ɜː": 61,
64
+ "ɡʲ": 62,
65
+ "ɪɹ": 63,
66
+ "ʊɹ": 64,
67
+ "aɪə": 65,
68
+ "aɪɚ": 66,
69
+ "iːː": 67,
70
+ "oːɹ": 68,
71
+ "ɑːɹ": 69,
72
+ "ɔːɹ": 70,
73
+
74
+ "1": 71,
75
+ "a": 72,
76
+ "e": 73,
77
+ "o": 74,
78
+ "q": 75,
79
+ "u": 76,
80
+ "y": 77,
81
+ "ɑ": 78,
82
+ "ɒ": 79,
83
+ "ɕ": 80,
84
+ "ɣ": 81,
85
+ "ɫ": 82,
86
+ "ɯ": 83,
87
+ "ʐ": 84,
88
+ "ʲ": 85,
89
+ "a1": 86,
90
+ "a2": 87,
91
+ "a5": 88,
92
+ "ai": 89,
93
+ "aɜ": 90,
94
+ "aː": 91,
95
+ "ei": 92,
96
+ "eə": 93,
97
+ "i.": 94,
98
+ "i1": 95,
99
+ "i2": 96,
100
+ "i5": 97,
101
+ "io": 98,
102
+ "iɑ": 99,
103
+ "iɛ": 100,
104
+ "iɜ": 101,
105
+ "i̪": 102,
106
+ "kh": 103,
107
+ "nʲ": 104,
108
+ "o1": 105,
109
+ "o2": 106,
110
+ "o5": 107,
111
+ "ou": 108,
112
+ "oɜ": 109,
113
+ "ph": 110,
114
+ "s.": 111,
115
+ "th": 112,
116
+ "ts": 113,
117
+ "tɕ": 114,
118
+ "u1": 115,
119
+ "u2": 116,
120
+ "u5": 117,
121
+ "ua": 118,
122
+ "uo": 119,
123
+ "uə": 120,
124
+ "uɜ": 121,
125
+ "y1": 122,
126
+ "y2": 123,
127
+ "y5": 124,
128
+ "yu": 125,
129
+ "yæ": 126,
130
+ "yə": 127,
131
+ "yɛ": 128,
132
+ "yɜ": 129,
133
+ "ŋɜ": 130,
134
+ "ŋʲ": 131,
135
+ "ɑ1": 132,
136
+ "ɑ2": 133,
137
+ "ɑ5": 134,
138
+ "ɑu": 135,
139
+ "ɑɜ": 136,
140
+ "ɑʲ": 137,
141
+ "ə1": 138,
142
+ "ə2": 139,
143
+ "ə5": 140,
144
+ "ər": 141,
145
+ "əɜ": 142,
146
+ "əʊ": 143,
147
+ "ʊə": 144,
148
+ "ai1": 145,
149
+ "ai2": 146,
150
+ "ai5": 147,
151
+ "aiɜ": 148,
152
+ "ei1": 149,
153
+ "ei2": 150,
154
+ "ei5": 151,
155
+ "eiɜ": 152,
156
+ "i.1": 153,
157
+ "i.2": 154,
158
+ "i.5": 155,
159
+ "i.ɜ": 156,
160
+ "io5": 157,
161
+ "iou": 158,
162
+ "iɑ1": 159,
163
+ "iɑ2": 160,
164
+ "iɑ5": 161,
165
+ "iɑɜ": 162,
166
+ "iɛ1": 163,
167
+ "iɛ2": 164,
168
+ "iɛ5": 165,
169
+ "iɛɜ": 166,
170
+ "i̪1": 167,
171
+ "i̪2": 168,
172
+ "i̪5": 169,
173
+ "i̪ɜ": 170,
174
+ "onɡ": 171,
175
+ "ou1": 172,
176
+ "ou2": 173,
177
+ "ou5": 174,
178
+ "ouɜ": 175,
179
+ "ts.": 176,
180
+ "tsh": 177,
181
+ "tɕh": 178,
182
+ "u5ʲ": 179,
183
+ "ua1": 180,
184
+ "ua2": 181,
185
+ "ua5": 182,
186
+ "uai": 183,
187
+ "uaɜ": 184,
188
+ "uei": 185,
189
+ "uo1": 186,
190
+ "uo2": 187,
191
+ "uo5": 188,
192
+ "uoɜ": 189,
193
+ "uə1": 190,
194
+ "uə2": 191,
195
+ "uə5": 192,
196
+ "uəɜ": 193,
197
+ "yiɜ": 194,
198
+ "yu2": 195,
199
+ "yu5": 196,
200
+ "yæ2": 197,
201
+ "yæ5": 198,
202
+ "yæɜ": 199,
203
+ "yə2": 200,
204
+ "yə5": 201,
205
+ "yəɜ": 202,
206
+ "yɛ1": 203,
207
+ "yɛ2": 204,
208
+ "yɛ5": 205,
209
+ "yɛɜ": 206,
210
+ "ɑu1": 207,
211
+ "ɑu2": 208,
212
+ "ɑu5": 209,
213
+ "ɑuɜ": 210,
214
+ "ər1": 211,
215
+ "ər2": 212,
216
+ "ər5": 213,
217
+ "ərɜ": 214,
218
+ "əː1": 215,
219
+ "iou1": 216,
220
+ "iou2": 217,
221
+ "iou5": 218,
222
+ "iouɜ": 219,
223
+ "onɡ1": 220,
224
+ "onɡ2": 221,
225
+ "onɡ5": 222,
226
+ "onɡɜ": 223,
227
+ "ts.h": 224,
228
+ "uai2": 225,
229
+ "uai5": 226,
230
+ "uaiɜ": 227,
231
+ "uei1": 228,
232
+ "uei2": 229,
233
+ "uei5": 230,
234
+ "ueiɜ": 231,
235
+ "uoɜʲ": 232,
236
+ "yɛ5ʲ": 233,
237
+ "ɑu2ʲ": 234,
238
+
239
+ "2": 235,
240
+ "5": 236,
241
+ "ɜ": 237,
242
+ "ʂ": 238,
243
+ "dʑ": 239,
244
+ "iɪ": 240,
245
+ "uɪ": 241,
246
+ "xʲ": 242,
247
+ "ɑt": 243,
248
+ "ɛɜ": 244,
249
+ "ɛː": 245,
250
+ "ɪː": 246,
251
+ "phʲ": 247,
252
+ "ɑ5ʲ": 248,
253
+ "ɑuʲ": 249,
254
+ "ərə": 250,
255
+ "uozʰ": 251,
256
+ "ər1ʲ": 252,
257
+ "tɕhtɕh": 253,
258
+
259
+ "c": 254,
260
+ "ʋ": 255,
261
+ "ʍ": 256,
262
+ "ʑ": 257,
263
+ "ː": 258,
264
+ "aə": 259,
265
+ "eː": 260,
266
+ "hʲ": 261,
267
+ "iʊ": 262,
268
+ "kʲ": 263,
269
+ "lʲ": 264,
270
+ "oə": 265,
271
+ "oɪ": 266,
272
+ "oʲ": 267,
273
+ "pʲ": 268,
274
+ "sʲ": 269,
275
+ "u4": 270,
276
+ "uʲ": 271,
277
+ "yi": 272,
278
+ "yʲ": 273,
279
+ "ŋ2": 274,
280
+ "ŋ5": 275,
281
+ "ŋ̩": 276,
282
+ "ɑɪ": 277,
283
+ "ɑʊ": 278,
284
+ "ɕʲ": 279,
285
+ "ət": 280,
286
+ "əə": 281,
287
+ "əɪ": 282,
288
+ "əʲ": 283,
289
+ "ɛ1": 284,
290
+ "ɛ5": 285,
291
+ "aiə": 286,
292
+ "aiɪ": 287,
293
+ "azʰ": 288,
294
+ "eiə": 289,
295
+ "eiɪ": 290,
296
+ "eiʊ": 291,
297
+ "i.ə": 292,
298
+ "i.ɪ": 293,
299
+ "i.ʊ": 294,
300
+ "ioɜ": 295,
301
+ "izʰ": 296,
302
+ "iɑə": 297,
303
+ "iɑʊ": 298,
304
+ "iɑʲ": 299,
305
+ "iɛə": 300,
306
+ "iɛɪ": 301,
307
+ "iɛʊ": 302,
308
+ "i̪ə": 303,
309
+ "i̪ʊ": 304,
310
+ "khʲ": 305,
311
+ "ouʲ": 306,
312
+ "tsʲ": 307,
313
+ "u2ʲ": 308,
314
+ "uoɪ": 309,
315
+ "uzʰ": 310,
316
+ "uɜʲ": 311,
317
+ "yæɪ": 312,
318
+ "yəʊ": 313,
319
+ "ərt": 314,
320
+ "ərɪ": 315,
321
+ "ərʲ": 316,
322
+ "əːt": 317,
323
+ "iouə": 318,
324
+ "iouʊ": 319,
325
+ "iouʲ": 320,
326
+ "iɛzʰ": 321,
327
+ "onɡə": 322,
328
+ "onɡɪ": 323,
329
+ "onɡʊ": 324,
330
+ "ouzʰ": 325,
331
+ "uai1": 326,
332
+ "ueiɪ": 327,
333
+ "ɑuzʰ": 328,
334
+ "iouzʰ": 329
335
+ }
utils/g2p/mls_en.json ADDED
@@ -0,0 +1,323 @@
1
+ {
2
+ "vocab": {
3
+ ",": 0,
4
+ ".": 1,
5
+ "?": 2,
6
+ "!": 3,
7
+ "_": 4,
8
+ "iː": 5,
9
+ "ɪ": 6,
10
+ "ɜː": 7,
11
+ "ɚ": 8,
12
+ "oːɹ": 9,
13
+ "ɔː": 10,
14
+ "ɔːɹ": 11,
15
+ "ɑː": 12,
16
+ "uː": 13,
17
+ "ʊ": 14,
18
+ "ɑːɹ": 15,
19
+ "ʌ": 16,
20
+ "ɛ": 17,
21
+ "æ": 18,
22
+ "eɪ": 19,
23
+ "aɪ": 20,
24
+ "ɔɪ": 21,
25
+ "aʊ": 22,
26
+ "oʊ": 23,
27
+ "ɪɹ": 24,
28
+ "ɛɹ": 25,
29
+ "ʊɹ": 26,
30
+ "p": 27,
31
+ "b": 28,
32
+ "t": 29,
33
+ "d": 30,
34
+ "k": 31,
35
+ "ɡ": 32,
36
+ "f": 33,
37
+ "v": 34,
38
+ "θ": 35,
39
+ "ð": 36,
40
+ "s": 37,
41
+ "z": 38,
42
+ "ʃ": 39,
43
+ "ʒ": 40,
44
+ "h": 41,
45
+ "tʃ": 42,
46
+ "dʒ": 43,
47
+ "m": 44,
48
+ "n": 45,
49
+ "ŋ": 46,
50
+ "j": 47,
51
+ "w": 48,
52
+ "ɹ": 49,
53
+ "l": 50,
54
+ "tɹ": 51,
55
+ "dɹ": 52,
56
+ "ts": 53,
57
+ "dz": 54,
58
+ "i": 55,
59
+ "ɔ": 56,
60
+ "ə": 57,
61
+ "ɾ": 58,
62
+ "iə": 59,
63
+ "r": 60,
64
+ "u": 61,
65
+ "oː": 62,
66
+ "ɛː": 63,
67
+ "ɪː": 64,
68
+ "aɪə": 65,
69
+ "aɪɚ": 66,
70
+ "ɑ̃": 67,
71
+ "ç": 68,
72
+ "ɔ̃": 69,
73
+ "ææ": 70,
74
+ "ɐɐ": 71,
75
+ "ɡʲ": 72,
76
+ "nʲ": 73,
77
+ "iːː": 74,
78
+
79
+ "p⁼": 75,
80
+ "pʰ": 76,
81
+ "t⁼": 77,
82
+ "tʰ": 78,
83
+ "k⁼": 79,
84
+ "kʰ": 80,
85
+ "x": 81,
86
+ "tʃ⁼": 82,
87
+ "tʃʰ": 83,
88
+ "ts`⁼": 84,
89
+ "ts`ʰ": 85,
90
+ "s`": 86,
91
+ "ɹ`": 87,
92
+ "ts⁼": 88,
93
+ "tsʰ": 89,
94
+ "p⁼wo": 90,
95
+ "p⁼wo→": 91,
96
+ "p⁼wo↑": 92,
97
+ "p⁼wo↓↑": 93,
98
+ "p⁼wo↓": 94,
99
+ "pʰwo": 95,
100
+ "pʰwo→": 96,
101
+ "pʰwo↑": 97,
102
+ "pʰwo↓↑": 98,
103
+ "pʰwo↓": 99,
104
+ "mwo": 100,
105
+ "mwo→": 101,
106
+ "mwo↑": 102,
107
+ "mwo↓↑": 103,
108
+ "mwo↓": 104,
109
+ "fwo": 105,
110
+ "fwo→": 106,
111
+ "fwo↑": 107,
112
+ "fwo↓↑": 108,
113
+ "fwo↓": 109,
114
+ "jɛn": 110,
115
+ "jɛn→": 111,
116
+ "jɛn↑": 112,
117
+ "jɛn↓↑": 113,
118
+ "jɛn↓": 114,
119
+ "ɥæn": 115,
120
+ "ɥæn→": 116,
121
+ "ɥæn↑": 117,
122
+ "ɥæn↓↑": 118,
123
+ "ɥæn↓": 119,
124
+ "in": 120,
125
+ "in→": 121,
126
+ "in↑": 122,
127
+ "in↓↑": 123,
128
+ "in↓": 124,
129
+ "ɥn": 125,
130
+ "ɥn→": 126,
131
+ "ɥn↑": 127,
132
+ "ɥn↓↑": 128,
133
+ "ɥn↓": 129,
134
+ "iŋ": 130,
135
+ "iŋ→": 131,
136
+ "iŋ↑": 132,
137
+ "iŋ↓↑": 133,
138
+ "iŋ↓": 134,
139
+ "ʊŋ": 135,
140
+ "ʊŋ→": 136,
141
+ "ʊŋ↑": 137,
142
+ "ʊŋ↓↑": 138,
143
+ "ʊŋ↓": 139,
144
+ "jʊŋ": 140,
145
+ "jʊŋ→": 141,
146
+ "jʊŋ↑": 142,
147
+ "jʊŋ↓↑": 143,
148
+ "jʊŋ↓": 144,
149
+ "ia": 145,
150
+ "ia→": 146,
151
+ "ia↑": 147,
152
+ "ia↓↑": 148,
153
+ "ia↓": 149,
154
+ "iɛ": 150,
155
+ "iɛ→": 151,
156
+ "iɛ↑": 152,
157
+ "iɛ↓↑": 153,
158
+ "iɛ↓": 154,
159
+ "iɑʊ": 155,
160
+ "iɑʊ→": 156,
161
+ "iɑʊ↑": 157,
162
+ "iɑʊ↓↑": 158,
163
+ "iɑʊ↓": 159,
164
+ "ioʊ": 160,
165
+ "ioʊ→": 161,
166
+ "ioʊ↑": 162,
167
+ "ioʊ↓↑": 163,
168
+ "ioʊ↓": 164,
169
+ "iɑŋ": 165,
170
+ "iɑŋ→": 166,
171
+ "iɑŋ↑": 167,
172
+ "iɑŋ↓↑": 168,
173
+ "iɑŋ↓": 169,
174
+ "ua": 170,
175
+ "ua→": 171,
176
+ "ua↑": 172,
177
+ "ua↓↑": 173,
178
+ "ua↓": 174,
179
+ "uo": 175,
180
+ "uo→": 176,
181
+ "uo↑": 177,
182
+ "uo↓↑": 178,
183
+ "uo↓": 179,
184
+ "uaɪ": 180,
185
+ "uaɪ→": 181,
186
+ "uaɪ↑": 182,
187
+ "uaɪ↓↑": 183,
188
+ "uaɪ↓": 184,
189
+ "ueɪ": 185,
190
+ "ueɪ→": 186,
191
+ "ueɪ↑": 187,
192
+ "ueɪ↓↑": 188,
193
+ "ueɪ↓": 189,
194
+ "uan": 190,
195
+ "uan→": 191,
196
+ "uan↑": 192,
197
+ "uan↓↑": 193,
198
+ "uan↓": 194,
199
+ "uən": 195,
200
+ "uən→": 196,
201
+ "uən↑": 197,
202
+ "uən↓↑": 198,
203
+ "uən↓": 199,
204
+ "uɑŋ": 200,
205
+ "uɑŋ→": 201,
206
+ "uɑŋ↑": 202,
207
+ "uɑŋ↓↑": 203,
208
+ "uɑŋ↓": 204,
209
+ "ɥɛ": 205,
210
+ "ɥɛ→": 206,
211
+ "ɥɛ↑": 207,
212
+ "ɥɛ↓↑": 208,
213
+ "ɥɛ↓": 209,
214
+ "a": 210,
215
+ "a→": 211,
216
+ "a↑": 212,
217
+ "a↓↑": 213,
218
+ "a↓": 214,
219
+ "o": 215,
220
+ "o→": 216,
221
+ "o↑": 217,
222
+ "o↓↑": 218,
223
+ "o↓": 219,
224
+ "ə→": 220,
225
+ "ə↑": 221,
226
+ "ə↓↑": 222,
227
+ "ə↓": 223,
228
+ "ɛ→": 224,
229
+ "ɛ↑": 225,
230
+ "ɛ↓↑": 226,
231
+ "ɛ↓": 227,
232
+ "aɪ→": 228,
233
+ "aɪ↑": 229,
234
+ "aɪ↓↑": 230,
235
+ "aɪ↓": 231,
236
+ "eɪ→": 232,
237
+ "eɪ↑": 233,
238
+ "eɪ↓↑": 234,
239
+ "eɪ↓": 235,
240
+ "ɑʊ": 236,
241
+ "ɑʊ→": 237,
242
+ "ɑʊ↑": 238,
243
+ "ɑʊ↓↑": 239,
244
+ "ɑʊ↓": 240,
245
+ "oʊ→": 241,
246
+ "oʊ↑": 242,
247
+ "oʊ↓↑": 243,
248
+ "oʊ↓": 244,
249
+ "an": 245,
250
+ "an→": 246,
251
+ "an↑": 247,
252
+ "an↓↑": 248,
253
+ "an↓": 249,
254
+ "ən": 250,
255
+ "ən→": 251,
256
+ "ən↑": 252,
257
+ "ən↓↑": 253,
258
+ "ən↓": 254,
259
+ "ɑŋ": 255,
260
+ "ɑŋ→": 256,
261
+ "ɑŋ↑": 257,
262
+ "ɑŋ↓↑": 258,
263
+ "ɑŋ↓": 259,
264
+ "əŋ": 260,
265
+ "əŋ→": 261,
266
+ "əŋ↑": 262,
267
+ "əŋ↓↑": 263,
268
+ "əŋ↓": 264,
269
+ "əɹ": 265,
270
+ "əɹ→": 266,
271
+ "əɹ↑": 267,
272
+ "əɹ↓↑": 268,
273
+ "əɹ↓": 269,
274
+ "i→": 270,
275
+ "i↑": 271,
276
+ "i↓↑": 272,
277
+ "i↓": 273,
278
+ "u→": 274,
279
+ "u↑": 275,
280
+ "u↓↑": 276,
281
+ "u↓": 277,
282
+ "ɥ": 278,
283
+ "ɥ→": 279,
284
+ "ɥ↑": 280,
285
+ "ɥ↓↑": 281,
286
+ "ɥ↓": 282,
287
+ "ts`⁼ɹ": 283,
288
+ "ts`⁼ɹ→": 284,
289
+ "ts`⁼ɹ↑": 285,
290
+ "ts`⁼ɹ↓↑": 286,
291
+ "ts`⁼ɹ↓": 287,
292
+ "ts`ʰɹ": 288,
293
+ "ts`ʰɹ→": 289,
294
+ "ts`ʰɹ↑": 290,
295
+ "ts`ʰɹ↓↑": 291,
296
+ "ts`ʰɹ↓": 292,
297
+ "s`ɹ": 293,
298
+ "s`ɹ→": 294,
299
+ "s`ɹ↑": 295,
300
+ "s`ɹ↓↑": 296,
301
+ "s`ɹ↓": 297,
302
+ "ɹ`ɹ": 298,
303
+ "ɹ`ɹ→": 299,
304
+ "ɹ`ɹ↑": 300,
305
+ "ɹ`ɹ↓↑": 301,
306
+ "ɹ`ɹ↓": 302,
307
+ "ts⁼ɹ": 303,
308
+ "ts⁼ɹ→": 304,
309
+ "ts⁼ɹ↑": 305,
310
+ "ts⁼ɹ↓↑": 306,
311
+ "ts⁼ɹ↓": 307,
312
+ "tsʰɹ": 308,
313
+ "tsʰɹ→": 309,
314
+ "tsʰɹ↑": 310,
315
+ "tsʰɹ↓↑": 311,
316
+ "tsʰɹ↓": 312,
317
+ "sɹ": 313,
318
+ "sɹ→": 314,
319
+ "sɹ↑": 315,
320
+ "sɹ↓↑": 316,
321
+ "sɹ↓": 317
322
+ }
323
+ }