# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

import logging
from contextlib import nullcontext

# If your Python version is < 3.7, use the following import instead:
# from contextlib import suppress as nullcontext

import torch
from torch.nn.utils import clip_grad_norm_


class Executor:
    def __init__(self):
        self.step = 0
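        # self.step counts optimizer updates (not raw batches); it is used as
        # the global step for TensorBoard logging and keeps increasing across
        # epochs as long as the same Executor instance is reused.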

    def train(
        self, model, optimizer, scheduler, data_loader, device, writer, args, scaler
    ):
        """Train one epoch"""
        model.train()
        clip = args.get("grad_clip", 50.0)
        log_interval = args.get("log_interval", 10)
        rank = args.get("rank", 0)
        epoch = args.get("epoch", 0)
        accum_grad = args.get("accum_grad", 1)
        is_distributed = args.get("is_distributed", True)
        use_amp = args.get("use_amp", False)
        logging.info(
            "using accumulate grad, new batch size is {} times"
            " larger than before".format(accum_grad)
        )
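        # For illustration: with accum_grad = 4 and 16 utterances per forward
        # pass, each optimizer update effectively sees 64 utterances.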
        if use_amp:
            assert scaler is not None
        # A context manager to be used in conjunction with an instance of
        # torch.nn.parallel.DistributedDataParallel to be able to train
        # with uneven inputs across participating processes.
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model_context = model.join
        else:
            model_context = nullcontext
        num_seen_utts = 0
        with model_context():
            for batch_idx, batch in enumerate(data_loader):
                key, feats, target, feats_lengths, target_lengths = batch
                feats = feats.to(device)
                target = target.to(device)
                feats_lengths = feats_lengths.to(device)
                target_lengths = target_lengths.to(device)
                num_utts = target_lengths.size(0)
                if num_utts == 0:
                    continue
                context = None
                # Disable gradient synchronization across DDP processes.
                # Within this context, gradients will be accumulated on module
                # variables, which will later be synchronized.
                if is_distributed and batch_idx % accum_grad != 0:
                    context = model.no_sync
                # Used for single-GPU training, and for the DDP batches on
                # which gradients are synchronized.
                else:
                    context = nullcontext
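                # Note: all-reduce therefore happens only on the batches where
                # an optimizer step is taken (batch_idx % accum_grad == 0); the
                # other micro-batches accumulate gradients locally, which
                # reduces communication overhead.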
                with context():
                    # autocast context for automatic mixed precision (AMP).
                    # More details about AMP can be found at
                    # https://pytorch.org/docs/stable/notes/amp_examples.html
                    with torch.cuda.amp.autocast(scaler is not None):
                        loss_dict = model(feats, feats_lengths, target, target_lengths)
                        loss = loss_dict["loss"] / accum_grad
                    if use_amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()
                num_seen_utts += num_utts
                if batch_idx % accum_grad == 0:
                    if rank == 0 and writer is not None:
                        writer.add_scalar("train_loss", loss, self.step)
                    # Use mixed precision training
                    if use_amp:
                        scaler.unscale_(optimizer)
                        grad_norm = clip_grad_norm_(model.parameters(), clip)
                        # Must invoke scaler.update() if unscale_() is used in
                        # the iteration to avoid the following error:
                        #   RuntimeError: unscale_() has already been called
                        #   on this optimizer since the last update().
                        # We don't check the gradient norm here because, if the
                        # gradients contain inf/nan values, scaler.step() will
                        # skip optimizer.step().
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        grad_norm = clip_grad_norm_(model.parameters(), clip)
                        if torch.isfinite(grad_norm):
                            optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                    self.step += 1
                if batch_idx % log_interval == 0:
                    lr = optimizer.param_groups[0]["lr"]
                    log_str = "TRAIN Batch {}/{} loss {:.6f} ".format(
                        epoch, batch_idx, loss.item() * accum_grad
                    )
                    for name, value in loss_dict.items():
                        if name != "loss" and value is not None:
                            log_str += "{} {:.6f} ".format(name, value.item())
                    log_str += "lr {:.8f} rank {}".format(lr, rank)
                    logging.debug(log_str)

    def cv(self, model, data_loader, device, args):
        """Cross validation on the given data_loader"""
        model.eval()
        rank = args.get("rank", 0)
        epoch = args.get("epoch", 0)
        log_interval = args.get("log_interval", 10)
        # in order to avoid division by 0
        num_seen_utts = 1
        total_loss = 0.0
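        # total_loss accumulates loss.item() * num_utts, so
        # total_loss / num_seen_utts is the running per-utterance average
        # loss; the caller can form the same average from the returned pair.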
        with torch.no_grad():
            for batch_idx, batch in enumerate(data_loader):
                key, feats, target, feats_lengths, target_lengths = batch
                feats = feats.to(device)
                target = target.to(device)
                feats_lengths = feats_lengths.to(device)
                target_lengths = target_lengths.to(device)
                num_utts = target_lengths.size(0)
                if num_utts == 0:
                    continue
                loss_dict = model(feats, feats_lengths, target, target_lengths)
                loss = loss_dict["loss"]
                if torch.isfinite(loss):
                    num_seen_utts += num_utts
                    total_loss += loss.item() * num_utts
                if batch_idx % log_interval == 0:
                    log_str = "CV Batch {}/{} loss {:.6f} ".format(
                        epoch, batch_idx, loss.item()
                    )
                    for name, value in loss_dict.items():
                        if name != "loss" and value is not None:
                            log_str += "{} {:.6f} ".format(name, value.item())
                    log_str += "history loss {:.6f}".format(total_loss / num_seen_utts)
                    log_str += " rank {}".format(rank)
                    logging.debug(log_str)
        return total_loss, num_seen_utts
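

# A minimal usage sketch, kept as a comment (illustrative only; `configs`,
# `train_loader`, `dev_loader`, `num_epochs`, and the model/optimizer/
# scheduler/writer objects are assumptions, not part of this module):
#
#   executor = Executor()
#   scaler = torch.cuda.amp.GradScaler() if configs.get("use_amp", False) else None
#   for epoch in range(num_epochs):
#       configs["epoch"] = epoch
#       executor.train(model, optimizer, scheduler, train_loader, device,
#                      writer, configs, scaler)
#       total_loss, num_seen_utts = executor.cv(model, dev_loader, device, configs)
#       cv_loss = total_loss / num_seen_utts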