# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_lightning_module.py
import os, sys

now_dir = os.getcwd()
sys.path.append(now_dir)
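# NOTE: appending the current working directory to sys.path makes the local
# AR package importable when this file is run from the repository root.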

from typing import Dict

import torch
from pytorch_lightning import LightningModule

from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam


class Text2SemanticLightningModule(LightningModule):
    def __init__(self, config, output_dir, is_train=True):
        super().__init__()
        self.config = config
        self.top_k = 3
        self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
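        # Optional warm start: if config provides "pretrained_s1", the checkpoint
        # is assumed to store the weights under a "weight" key (older checkpoints
        # used "state_dict", as in the commented-out line below).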
        pretrained_s1 = config.get("pretrained_s1")
        if pretrained_s1 and is_train:
            # print(self.load_state_dict(torch.load(pretrained_s1, map_location="cpu")["state_dict"]))
            print(self.load_state_dict(torch.load(pretrained_s1, map_location="cpu")["weight"]))
        if is_train:
            self.automatic_optimization = False
            self.save_hyperparameters()
            self.eval_dir = output_dir / 'eval'
            self.eval_dir.mkdir(parents=True, exist_ok=True)
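
    # automatic_optimization is disabled in __init__, so training_step runs the
    # backward pass itself and steps the optimizer/scheduler only on every 4th
    # batch, i.e. gradients are accumulated over 4 batches.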
    def training_step(self, batch: Dict, batch_idx: int):
        opt = self.optimizers()
        scheduler = self.lr_schedulers()
        loss, acc = self.model.forward(
            batch['phoneme_ids'], batch['phoneme_ids_len'],
            batch['semantic_ids'], batch['semantic_ids_len'],
            batch['bert_feature'])
        self.manual_backward(loss)
        if batch_idx > 0 and batch_idx % 4 == 0:
            opt.step()
            opt.zero_grad()
            scheduler.step()

        self.log(
            "total_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)
        self.log(
            "lr",
            scheduler.get_last_lr()[0],
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)
        self.log(
            f"top_{self.top_k}_acc",
            acc,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True)

    def validation_step(self, batch: Dict, batch_idx: int):
        return
        # # get loss
        # loss, acc = self.model.forward(
        #     batch['phoneme_ids'], batch['phoneme_ids_len'],
        #     batch['semantic_ids'], batch['semantic_ids_len'],
        #     batch['bert_feature']
        # )
        #
        # self.log(
        #     "val_total_loss",
        #     loss,
        #     on_step=True,
        #     on_epoch=True,
        #     prog_bar=True,
        #     sync_dist=True)
        # self.log(
        #     f"val_top_{self.top_k}_acc",
        #     acc,
        #     on_step=True,
        #     on_epoch=True,
        #     prog_bar=True,
        #     sync_dist=True)
        #
        # # get infer output
        # semantic_len = batch['semantic_ids'].size(1)
        # prompt_len = min(int(semantic_len * 0.5), 150)
        # prompt = batch['semantic_ids'][:, :prompt_len]
        # pred_semantic = self.model.infer(batch['phoneme_ids'],
        #                                  batch['phoneme_ids_len'], prompt,
        #                                  batch['bert_feature'])
        # save_name = f'semantic_toks_{batch_idx}.pt'
        # save_path = os.path.join(self.eval_dir, save_name)
        # torch.save(pred_semantic.detach().cpu(), save_path)
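        # NOTE: validation is effectively a no-op because of the early return
        # above. If re-enabled, the commented block computes validation
        # loss/accuracy and writes inferred semantic tokens into self.eval_dir.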

    def configure_optimizers(self):
        model_parameters = self.model.parameters()
        parameters_names = []
        parameters_names.append(
            [name_param_pair[0] for name_param_pair in self.model.named_parameters()])
        lm_opt = ScaledAdam(
            model_parameters,
            lr=0.01,
            betas=(0.9, 0.95),
            clipping_scale=2.0,
            parameters_names=parameters_names,
            show_dominant_parameters=False,
            clipping_update_period=1000)
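
        # NOTE: the lr=0.01 passed to ScaledAdam is only an initial value; the
        # WarmupCosineLRSchedule returned below is expected to drive the actual
        # learning rate from lr_init up to lr and down to lr_end per the config.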
        return {
            "optimizer": lm_opt,
            "lr_scheduler": {
                "scheduler": WarmupCosineLRSchedule(
                    lm_opt,
                    init_lr=self.config['optimizer']['lr_init'],
                    peak_lr=self.config['optimizer']['lr'],
                    end_lr=self.config['optimizer']['lr_end'],
                    warmup_steps=self.config['optimizer']['warmup_steps'],
                    total_steps=self.config['optimizer']['decay_steps'])
            }
        }
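

# --- Usage sketch (illustrative only, not part of the training pipeline) ---
# Constructing the module needs a dict-like config and a pathlib.Path
# output_dir. Only the keys read in this file are shown; a real config also
# carries the Text2SemanticDecoder model hyper-parameters, and every value
# below is hypothetical.
#
#   from pathlib import Path
#   config = {
#       "pretrained_s1": None,  # or a path to a checkpoint holding a "weight" key
#       "optimizer": {
#           "lr_init": 1e-5,
#           "lr": 1e-4,
#           "lr_end": 1e-6,
#           "warmup_steps": 2000,
#           "decay_steps": 40000,
#       },
#   }
#   module = Text2SemanticLightningModule(config, Path("logs/s1"), is_train=True)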