from torch import nn

from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator


class MelganMultiscaleDiscriminator(nn.Module):
    """Apply a stack of MelGAN discriminators to progressively pooled input.

    ``num_scales`` independent ``MelganDiscriminator`` instances are built with
    identical hyper-parameters. In ``forward``, discriminator ``i`` sees the
    input after ``i`` applications of a shared ``nn.AvgPool1d``. Each scale is
    therefore a coarser, average-pooled view of the previous one.

    Args:
        in_channels, out_channels, kernel_sizes, base_channels, max_channels,
        downsample_factors: forwarded unchanged to every ``MelganDiscriminator``.
        num_scales: number of discriminators, i.e. number of scales.
        pooling_kernel_size, pooling_stride, pooling_padding: configuration of
            the ``nn.AvgPool1d`` used between consecutive scales. It is built
            with ``count_include_pad=False``, so padded positions are excluded
            from each average.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 num_scales=3,
                 kernel_sizes=(5, 3),
                 base_channels=16,
                 max_channels=1024,
                 downsample_factors=(4, 4, 4),
                 pooling_kernel_size=4,
                 pooling_stride=2,
                 pooling_padding=1):
        super().__init__()

        # One discriminator per scale. The weights are NOT shared across scales.
        self.discriminators = nn.ModuleList([
            MelganDiscriminator(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_sizes=kernel_sizes,
                                base_channels=base_channels,
                                max_channels=max_channels,
                                downsample_factors=downsample_factors)
            for _ in range(num_scales)
        ])

        # Shared downsampler that produces the next, coarser scale.
        self.pooling = nn.AvgPool1d(kernel_size=pooling_kernel_size,
                                    stride=pooling_stride,
                                    padding=pooling_padding,
                                    count_include_pad=False)

    def forward(self, x):
        """Run every discriminator on its own scale of ``x``.

        Args:
            x: input tensor, pooled along its last dimension between scales.
                Presumably ``(batch, in_channels, time)``; TODO confirm
                against callers.

        Returns:
            ``(scores, feats)``: two lists with one entry per discriminator,
            in scale order (finest first). Entry ``i`` holds the
            ``(score, feat)`` pair returned by discriminator ``i``. Their
            exact types are defined by ``MelganDiscriminator``.
        """
        scores = []
        feats = []
        last_idx = len(self.discriminators) - 1
        for idx, disc in enumerate(self.discriminators):
            score, feat = disc(x)
            scores.append(score)
            feats.append(feat)
            # Only downsample when another scale will consume the result.
            # Pooling after the final discriminator would be discarded work.
            if idx < last_idx:
                x = self.pooling(x)
        return scores, feats