import torch
from torch import nn
from .espnet_positional_embedding import RelPositionalEncoding
from .espnet_transformer_attn import RelPositionMultiHeadedAttention, MultiHeadedAttention
from .layers import Swish, ConvolutionModule, EncoderLayer, MultiLayeredConv1d
from ..layers import Embedding


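def sequence_mask(length, max_length=None):
    """Build a boolean mask that is True at valid (non-padded) positions.

    :param length: [B] integer tensor of sequence lengths
    :param max_length: mask width; defaults to length.max()
    :return: [B, max_length] bool tensor
    """
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)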


class ConformerLayers(nn.Module):
    def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4, use_last_norm=True):
        super().__init__()
        self.use_last_norm = use_last_norm
        self.layers = nn.ModuleList()
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
        self.encoder_layers = nn.ModuleList([EncoderLayer(
            hidden_size,
            MultiHeadedAttention(num_heads, hidden_size, 0.0),
            positionwise_layer(*positionwise_layer_args),
            positionwise_layer(*positionwise_layer_args),
            ConvolutionModule(hidden_size, kernel_size, Swish()),
            dropout,
        ) for _ in range(num_layers)])
        if self.use_last_norm:
            self.layer_norm = nn.LayerNorm(hidden_size)
        else:
            self.layer_norm = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, x_mask):
        """

        :param x: [B, T, H]
        :param x_mask: [B, T, 1], nonzero at valid (non-padded) positions
        :return: [B, T, H]
        """
        for layer in self.encoder_layers:
            x, _ = layer(x, x_mask)
        x = self.layer_norm(x) * x_mask
        return x


class ConformerEncoder(ConformerLayers):
    def __init__(self, hidden_size, dict_size=0, in_size=0, strides=(2, 2), num_layers=None):
        conformer_enc_kernel_size = 9
        super().__init__(hidden_size, num_layers, conformer_enc_kernel_size)
        self.dict_size = dict_size
        if dict_size != 0:
            self.embed = Embedding(dict_size, hidden_size, padding_idx=0)
        else:
            self.seq_proj_in = torch.nn.Linear(in_size, hidden_size)
            self.seq_proj_out = torch.nn.Linear(hidden_size, in_size)
        self.mel_in = torch.nn.Linear(160, hidden_size)
        self.mel_pre_net = torch.nn.Sequential(*[
                torch.nn.Conv1d(hidden_size, hidden_size, kernel_size=s * 2, stride=s, padding=s // 2)
                for s in strides
            ])

    def forward(self, seq_out, mels_timbre, other_embeds=0):
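        """

        :param seq_out: [B, T] token ids if dict_size != 0, else [B, T, in_size] features
        :param mels_timbre: [B, T_mel, 160] mel features (last dim fixed by self.mel_in)
        :param other_embeds: extra embeddings added to the input, broadcastable to [B, T, H]
        :return: [B, T, H] if dict_size != 0, else [B, T, in_size]
        """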
        x_lengths = (seq_out > 0).long().sum(-1)
        x = seq_out
        if self.dict_size != 0:
            x = self.embed(x) + other_embeds  # [B, T, H]
        else:
            x = self.seq_proj_in(x) + other_embeds  # [B, T, H]
        mels_timbre = self.mel_in(mels_timbre).transpose(1, 2)
        mels_timbre = self.mel_pre_net(mels_timbre).transpose(1, 2)

        T_out = x.size(1)
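        T_mel = mels_timbre.size(1)
        if self.dict_size != 0:
            # Mel frames are prepended, so every valid length grows by T_mel.
            x_mask = sequence_mask(x_lengths + T_mel, x.size(1) + T_mel).unsqueeze(2).to(x.dtype)
        else:
            # Mel frames are always valid; projected frames whose features are all zero are treated as padding.
            mel_mask = torch.ones(x.size(0), T_mel, 1, device=x.device)
            x_mask = torch.cat((mel_mask, (x.abs().sum(2) > 0).float()[:, :, None]), dim=1)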
        x = torch.cat((mels_timbre, x), 1)
        x = super().forward(x, x_mask)
        if self.dict_size != 0:
            x = x[:, -T_out:, :]
        else:
            x = self.seq_proj_out(x[:, -T_out:, :])
        return x


class ConformerDecoder(ConformerLayers):
    def __init__(self, hidden_size, num_layers):
        conformer_dec_kernel_size = 9
        super().__init__(hidden_size, num_layers, conformer_dec_kernel_size)