# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }

# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```
| """ScorerInterface implementation for CTC.""" | |
| import numpy as np | |
| import torch | |
| from modules.wenet_extractor.paraformer.search.ctc_prefix_score import CTCPrefixScore | |
| from modules.wenet_extractor.paraformer.search.ctc_prefix_score import CTCPrefixScoreTH | |
| from modules.wenet_extractor.paraformer.search.scorer_interface import ( | |
| BatchPartialScorerInterface, | |
| ) | |
class CTCPrefixScorer(BatchPartialScorerInterface):
    """Decoder interface wrapper for CTCPrefixScore."""

    def __init__(self, ctc: torch.nn.Module, eos: int):
        """Initialize class.

        Args:
            ctc (torch.nn.Module): The CTC implementation.
                For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
            eos (int): The end-of-sequence id.
        """
        self.ctc = ctc
        self.eos = eos
        # Lazily created by init_state / batch_init_state: either a
        # CTCPrefixScore (numpy, single utterance) or CTCPrefixScoreTH (batched).
        self.impl = None

    def init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state (score 0 plus the scorer's initial state)
        """
        logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
        # TODO(karita): use CTCPrefixScoreTH
        self.impl = CTCPrefixScore(logp, 0, self.eos, np)
        return 0, self.impl.initial_state()

    def select_state(self, state, i, new_id=None):
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label id to select a state if necessary

        Returns:
            state: pruned state
        """
        if isinstance(state, tuple):
            if len(state) == 2:  # for CTCPrefixScore
                sc, st = state
                return sc[i], st[i]
            else:  # for CTCPrefixScoreTH (need new_id > 0)
                r, log_psi, f_min, f_max, scoring_idmap = state
                s = log_psi[i, new_id].expand(log_psi.size(1))
                if scoring_idmap is not None:
                    return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
                else:
                    return r[:, :, i, new_id], s, f_min, f_max
        return None if state is None else state[i]

    def score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D prefix token
            ids (torch.Tensor): torch.int64 next tokens to score
            state: decoder state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape
                `(len(ids),)` and next state for ys
        """
        prev_score, state = state
        presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
        # Beam search accumulates scores, so return the delta w.r.t. the prefix.
        tscore = torch.as_tensor(
            presub_score - prev_score, device=x.device, dtype=x.dtype
        )
        return tscore, (presub_score, new_st)

    def batch_init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state (None; the scorer keeps state in self.impl)
        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))  # assuming batch_size = 1
        xlen = torch.tensor([logp.size(1)])
        self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
        return None

    def batch_score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D prefix token
            ids (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape
                `(len(ids),)` and next state for ys
        """
        # Re-batch the per-hypothesis states produced by select_state:
        # stack r along the hyp axis (dim=2) and log_psi along dim 0;
        # f_min / f_max are shared across hypotheses.
        batch_state = (
            (
                torch.stack([s[0] for s in state], dim=2),
                torch.stack([s[1] for s in state]),
                state[0][2],
                state[0][3],
            )
            if state[0] is not None
            else None
        )
        return self.impl(y, batch_state, ids)

    def extend_prob(self, x: torch.Tensor):
        """Extend probs for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            x (torch.Tensor): The encoded feature tensor
        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))
        self.impl.extend_prob(logp)

    def extend_state(self, state):
        """Extend state for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            state: The states of hyps

        Returns: extended state
        """
        new_state = []
        for s in state:
            new_state.append(self.impl.extend_state(s))
        return new_state