Spaces:
Sleeping
Sleeping
| # This module is from [WeNet](https://github.com/wenet-e2e/wenet). | |
| # ## Citations | |
| # ```bibtex | |
| # @inproceedings{yao2021wenet, | |
| # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit}, | |
| # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin}, | |
| # booktitle={Proc. Interspeech}, | |
| # year={2021}, | |
| # address={Brno, Czech Republic }, | |
| # organization={IEEE} | |
| # } | |
| # @article{zhang2022wenet, | |
| # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit}, | |
| # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei}, | |
| # journal={arXiv preprint arXiv:2203.15455}, | |
| # year={2022} | |
| # } | |
| # | |
| from typing import Optional | |
| import six | |
| import torch | |
| import numpy as np | |
def sequence_mask(
    lengths,
    maxlen: Optional[int] = None,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build a length mask from a tensor of sequence lengths.

    Element ``[..., j]`` of the result is 1 (``True`` for ``torch.bool``)
    when ``j < lengths[...]`` and 0 otherwise. The result has shape
    ``lengths.shape + (maxlen,)``.

    Args:
        lengths: Tensor of sequence lengths (any shape).
        maxlen: Size of the last mask dimension. Defaults to
            ``lengths.max()``, or 0 when ``lengths`` is empty.
        dtype: dtype of the returned mask.
        device: If given, the mask is moved to this device. Otherwise it
            stays on ``lengths.device``.

    Returns:
        The mask tensor.
    """
    if maxlen is None:
        # ``Tensor.max()`` raises on an empty tensor; an empty batch simply
        # produces a zero-width mask.
        maxlen = lengths.max() if lengths.numel() > 0 else 0
    # Allocate the position index directly on the lengths' device instead of
    # building it on the CPU and copying it over.
    positions = torch.arange(0, maxlen, 1, device=lengths.device)
    # Broadcast (maxlen,) against (..., 1) -> (..., maxlen). A comparison
    # result never requires grad, so no detach() is needed.
    mask = (positions < lengths.unsqueeze(-1)).type(dtype)
    return mask.to(device) if device is not None else mask
def end_detect(ended_hyps, i, M=3, d_end=np.log(1 * np.exp(-10))):
    """End detection.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    For each length ``i, i - 1, ..., i - M + 1`` this checks whether some
    ended hypothesis has that length and whether the best such hypothesis
    satisfies ``best_score_of_that_length - best_score_overall < d_end``.
    Decoding is reported as ended only if all ``M`` lengths satisfy this.

    :param ended_hyps: Finished hypotheses. Each is a dict with a ``"yseq"``
        sequence (its ``len`` is the hypothesis length) and a ``"score"``.
    :param i: Reference length. Lengths ``i - m`` for ``m in range(M)``
        are checked.
    :param M: Number of consecutive lengths that must satisfy the criterion.
    :param d_end: Threshold on the score difference
        (default ``log(exp(-10)) == -10``).
    :return: True if the criterion holds for all ``M`` lengths, else False
        (always False when ``ended_hyps`` is empty).
    """
    if not ended_hyps:
        return False
    # Only the best score is needed, so take the maximum directly rather
    # than sorting all hypotheses.
    best_score = max(hyp["score"] for hyp in ended_hyps)
    count = 0
    for m in range(M):
        hyp_length = i - m
        same_length_scores = [
            hyp["score"] for hyp in ended_hyps if len(hyp["yseq"]) == hyp_length
        ]
        # A length with no ended hypotheses does not count toward M.
        if same_length_scores and max(same_length_scores) - best_score < d_end:
            count += 1
    return count == M