AlexK-PL's picture
Upload 72 files
c61c48a
raw
history blame
800 Bytes
import torch
import torch.nn as nn
import torch.nn.functional as F
from .discriminator import Discriminator
from .identity import Identity
class MultiScaleDiscriminator(nn.Module):
    """Run several independent discriminators on progressively pooled input.

    One ``Discriminator`` is built per scale. The first scale is paired with
    ``Identity`` (presumably a pass-through; it is defined in a sibling
    module, so verify there), and every further scale is paired with an
    ``nn.AvgPool1d`` layer. Pooling is applied cumulatively in ``forward``,
    so scale ``i`` sees the input after ``i`` average-pooling steps.

    Args:
        num_scales: Number of discriminators (and therefore scales) to
            build. Defaults to 3, which reproduces the original fixed
            architecture and keeps existing ``state_dict`` keys valid.

    Raises:
        ValueError: If ``num_scales`` is smaller than 1.
    """

    def __init__(self, num_scales: int = 3):
        super().__init__()
        if num_scales < 1:
            raise ValueError(
                f"num_scales must be at least 1, got {num_scales}"
            )

        # One independent discriminator per scale.
        self.discriminators = nn.ModuleList(
            [Discriminator() for _ in range(num_scales)]
        )

        # The first scale gets no average pooling; each later entry pools
        # the output of the previous one. With kernel_size=4, stride=2,
        # padding=1 an even-length input has its last dimension halved.
        # count_include_pad=False keeps the zero padding out of the
        # averages at the two borders.
        self.pooling = nn.ModuleList(
            [Identity()]
            + [
                nn.AvgPool1d(
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    count_include_pad=False,
                )
                for _ in range(num_scales - 1)
            ]
        )

    def forward(self, x: torch.Tensor) -> list:
        """Evaluate every discriminator on its own scale of ``x``.

        Args:
            x: Input accepted by ``nn.AvgPool1d``, i.e. ``(N, C, L)`` or
                ``(C, L)``. Any further constraints come from
                ``Discriminator``, which is not visible here.

        Returns:
            A list with one entry per scale, ordered from the unpooled
            input to the most heavily pooled one. Each entry is whatever
            ``Discriminator`` returns; the original inline note describes
            it as a ``(feat, score)`` pair.
        """
        outputs = []
        for pool, disc in zip(self.pooling, self.discriminators):
            # Rebinding x is deliberate: each pooling step operates on the
            # result of the previous scale, not on the raw input.
            x = pool(x)
            outputs.append(disc(x))
        return outputs  # [(feat, score), (feat, score), (feat, score)]