Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .discriminator import Discriminator | |
from .identity import Identity | |
class MultiScaleDiscriminator(nn.Module):
    """Apply a stack of discriminators to progressively downsampled inputs.

    Scale 0 feeds the input through the project's ``Identity`` module
    (presumably a pass-through; confirm). Each later scale average-pools the
    *previous* scale's input with kernel 4, stride 2, padding 1, which maps a
    length ``L`` to ``floor(L / 2)``. Scale ``i`` therefore sees the input
    downsampled by roughly ``2 ** i``.

    Args:
        num_scales: Number of discriminators (one per resolution). Defaults
            to 3, matching the original hard-coded configuration.

    Raises:
        ValueError: If ``num_scales`` is less than 1.
    """

    def __init__(self, num_scales: int = 3):
        super().__init__()
        if num_scales < 1:
            raise ValueError(f"num_scales must be >= 1, got {num_scales}")
        self.discriminators = nn.ModuleList(
            [Discriminator() for _ in range(num_scales)]
        )
        # Exactly one pooling layer per discriminator, derived from the same
        # num_scales, so zip() in forward() can never silently drop a scale.
        # count_include_pad=False excludes the implicit zero padding from the
        # averaging divisor, so boundary samples are not biased toward zero.
        self.pooling = nn.ModuleList(
            [Identity()]
            + [
                nn.AvgPool1d(
                    kernel_size=4, stride=2, padding=1, count_include_pad=False
                )
                for _ in range(1, num_scales)
            ]
        )

    def forward(self, x):
        """Run every discriminator on its own resolution of ``x``.

        Args:
            x: Input tensor. Scales after the first pass it to
                ``nn.AvgPool1d``, which accepts ``(batch, channels, length)``
                or ``(channels, length)``. It is presumably a
                ``(batch, 1, samples)`` waveform; confirm against callers.

        Returns:
            A list with one entry per scale, finest resolution first. Each
            entry is whatever ``Discriminator.forward`` returns. The original
            code annotated these as ``(feat, score)`` pairs; confirm against
            ``Discriminator``.
        """
        outputs = []
        for pool, disc in zip(self.pooling, self.discriminators):
            # Reassigning x makes the downsampling cumulative across scales.
            x = pool(x)
            outputs.append(disc(x))
        return outputs