"""Positionwise feed forward layer definition."""

import torch


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    Applies two linear transformations with an activation and dropout in
    between; the output feature dimension matches the input dimension.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function applied after the
            first linear layer (default: ``torch.nn.ReLU()``).

    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Apply the feed forward network: w_2(dropout(activation(w_1(x)))).

        Args:
            x (torch.Tensor): Input tensor with ``idim`` as the last dimension.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
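

if __name__ == "__main__":
    # Minimal usage sketch, not part of the layer definition: it only illustrates
    # that the module maps back to the input feature dimension. The batch size,
    # sequence length, and layer sizes below are arbitrary assumptions chosen for
    # illustration, not values taken from this module.
    ffn = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1)
    x = torch.randn(4, 50, 256)  # (batch, time, idim)
    y = ffn(x)
    print(y.shape)  # torch.Size([4, 50, 256]): last dimension equals idim again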