# Provenance: uploaded by Steveeeeeeen (HF Staff), commit 7e6946d ("add model").
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""
from typing import Tuple, Union
import torch
class BaseSubsampling(torch.nn.Module):
    """Shared base for the subsampling layers in this module.

    Stores two integer attributes, ``right_context`` and
    ``subsampling_rate``, and exposes :meth:`position_encoding`, which
    forwards to ``self.pos_enc``.

    Note:
        ``self.pos_enc`` is *not* created here. A subclass has to assign it
        (``LinearNoSubsampling`` does so in its constructor) before
        :meth:`position_encoding` is called.
    """

    def __init__(self):
        """Initialise the module with its default bookkeeping attributes."""
        super().__init__()
        # Defaults; subclasses are free to overwrite either value.
        self.right_context, self.subsampling_rate = 0, 1

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """Delegate to ``self.pos_enc.position_encoding``.

        Args:
            offset (int or torch.Tensor): Forwarded unchanged as the first
                argument.
            size (int): Forwarded unchanged as the second argument.

        Returns:
            torch.Tensor: Whatever ``self.pos_enc.position_encoding``
            returns for ``(offset, size)``.
        """
        return self.pos_enc.position_encoding(offset, size)
class LinearNoSubsampling(BaseSubsampling):
    """Project the input with a linear layer; the time axis is not reduced.

    The projection is ``Linear -> LayerNorm -> Dropout``; its result is then
    handed to the supplied positional-encoding callable.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Probability passed to ``torch.nn.Dropout``.
        pos_enc_class (torch.nn.Module): Positional-encoding object. Despite
            the ``_class`` suffix it is stored and called as an instance:
            ``pos_enc_class(x, offset)`` must return a pair.
    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Build the projection stack and keep the positional encoder."""
        super().__init__()
        # Layers are created in projection -> norm -> dropout order and
        # registered under ``out`` with indices 0, 1, 2.
        projection = torch.nn.Linear(idim, odim)
        norm = torch.nn.LayerNorm(odim, eps=1e-5)
        dropout = torch.nn.Dropout(dropout_rate)
        self.out = torch.nn.Sequential(projection, norm, dropout)
        self.pos_enc = pos_enc_class
        # Restated explicitly, although BaseSubsampling.__init__ has already
        # set the same values.
        self.right_context, self.subsampling_rate = 0, 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the projection, then the positional encoding.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset (int or torch.Tensor): Forwarded unchanged to
                ``self.pos_enc``. Defaults to 0.

        Returns:
            torch.Tensor: First element returned by ``self.pos_enc`` for the
                projected input; per the layer contract
                (#batch, time', odim), where time' = time.
            torch.Tensor: Second element returned by ``self.pos_enc``
                (the positional embedding).
            torch.Tensor: ``x_mask`` returned as-is, (#batch, 1, time'),
                where time' = time.
        """
        projected = self.out(x)
        encoded, pos_emb = self.pos_enc(projected, offset)
        return encoded, pos_emb, x_mask