Spaces:
Build error
Build error
| # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py | |
| # reference: https://github.com/lifeiteng/vall-e | |
| import os, sys | |
| now_dir = os.getcwd() | |
| sys.path.append(now_dir) | |
| from typing import Dict | |
| import torch | |
| from pytorch_lightning import LightningModule | |
| from AR.models.t2s_model import Text2SemanticDecoder | |
| from AR.modules.lr_schedulers import WarmupCosineLRSchedule | |
| from AR.modules.optim import ScaledAdam | |
class Text2SemanticLightningModule(LightningModule):
    """Lightning wrapper that trains a ``Text2SemanticDecoder``.

    When constructed with ``is_train=True`` the module uses Lightning manual
    optimization: gradients from ``ACCUMULATE_GRAD_BATCHES`` consecutive
    batches are accumulated before each optimizer and LR-scheduler step.
    """

    # Number of batches whose gradients are accumulated before one
    # optimizer/scheduler step in ``training_step``.
    ACCUMULATE_GRAD_BATCHES = 4

    def __init__(self, config, output_dir, is_train=True):
        """Build the decoder and, when training, prepare training state.

        Args:
            config: Configuration mapping. It is forwarded to
                ``Text2SemanticDecoder``; this class itself reads the optional
                ``"pretrained_s1"`` key here, plus ``config["train"]`` and
                ``config["optimizer"]`` later during training.
            output_dir: Directory under which an ``eval`` subdirectory is
                created when ``is_train`` is true. Either ``str`` or
                ``pathlib.Path`` is accepted.
            is_train: When true, load the ``pretrained_s1`` checkpoint (if
                configured), switch to manual optimization, save the
                hyperparameters and create the eval directory.
        """
        from pathlib import Path  # local import: only needed for eval_dir

        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)

        pretrained_s1 = config.get("pretrained_s1")
        if pretrained_s1 and is_train:
            # NOTE(review): torch.load unpickles the file; only point
            # "pretrained_s1" at checkpoints from a trusted source.
            checkpoint = torch.load(pretrained_s1, map_location="cpu")
            # load_state_dict's return value (missing / unexpected keys) is
            # printed so mismatches are visible in the training log.
            print(self.load_state_dict(checkpoint["weight"]))

        if is_train:
            # training_step drives the optimizer itself (manual_backward,
            # opt.step, scheduler.step), so Lightning must not do it.
            self.automatic_optimization = False
            self.save_hyperparameters()
            # Path() makes the ``/`` join work when a plain string is passed.
            self.eval_dir = Path(output_dir) / "eval"
            self.eval_dir.mkdir(parents=True, exist_ok=True)

    def training_step(self, batch: Dict, batch_idx: int):
        """Run one training batch with manual gradient accumulation.

        ``batch`` must provide ``phoneme_ids``, ``phoneme_ids_len``,
        ``semantic_ids``, ``semantic_ids_len`` and ``bert_feature``. If
        ``config["train"]["if_dpo"]`` is truthy, ``model.forward`` is used;
        otherwise ``model.forward_old``. Either must return ``(loss, acc)``.
        Loss, learning rate and top-k accuracy are logged every call.
        """
        opt = self.optimizers()
        scheduler = self.lr_schedulers()

        use_dpo = bool(self.config["train"].get("if_dpo", False))
        forward = self.model.forward if use_dpo else self.model.forward_old
        loss, acc = forward(
            batch["phoneme_ids"],
            batch["phoneme_ids_len"],
            batch["semantic_ids"],
            batch["semantic_ids_len"],
            batch["bert_feature"],
        )
        self.manual_backward(loss)

        # batch_idx is 0-based, so counting with batch_idx + 1 makes every
        # accumulation window exactly ACCUMULATE_GRAD_BATCHES batches long
        # (including the first one of each epoch).
        if (batch_idx + 1) % self.ACCUMULATE_GRAD_BATCHES == 0:
            opt.step()
            opt.zero_grad()
            # The LR schedule advances once per optimizer step, not per batch.
            scheduler.step()

        self.log(
            "total_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

    def validation_step(self, batch: Dict, batch_idx: int):
        """Do nothing: validation is intentionally disabled.

        An earlier (commented-out) draft computed validation loss and top-k
        accuracy with ``model.forward`` and saved ``model.infer`` samples,
        prompted with a prefix of ``semantic_ids``, to ``self.eval_dir``.
        """
        return

    def configure_optimizers(self):
        """Create the ScaledAdam optimizer and its warmup-cosine scheduler.

        Scheduler settings are read from ``config["optimizer"]``:
        ``lr_init``, ``lr``, ``lr_end``, ``warmup_steps`` and ``decay_steps``.

        Returns:
            Lightning's optimizer/scheduler dict.
        """
        # ScaledAdam receives a list of name lists — presumably one per
        # parameter group; here all model parameters form a single group.
        parameters_names = [[name for name, _ in self.model.named_parameters()]]
        lm_opt = ScaledAdam(
            self.model.parameters(),
            lr=0.01,  # NOTE(review): presumably overridden by the scheduler — confirm
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000,
        )

        optim_cfg = self.config["optimizer"]
        scheduler = WarmupCosineLRSchedule(
            lm_opt,
            init_lr=optim_cfg["lr_init"],
            peak_lr=optim_cfg["lr"],
            end_lr=optim_cfg["lr_end"],
            warmup_steps=optim_cfg["warmup_steps"],
            total_steps=optim_cfg["decay_steps"],
        )
        return {
            "optimizer": lm_opt,
            "lr_scheduler": {"scheduler": scheduler},
        }