import torch
from torch import nn

from src.smb.level import MarioLevel
from src.utils.dl import SelfAttn

# Dimensionality of the latent noise vector fed to the generator.
nz = 20

# Self-Attention GAN (SAGAN) generator: upsamples a (nz, 1, 1) latent code to a
# (MarioLevel.n_types, 16, 16) tile-probability map with transposed convolutions,
# spectral normalisation, and self-attention at the two largest feature resolutions.
class SAGenerator(nn.Module):
    def __init__(self, base_channels=32):
        super(SAGenerator, self).__init__()
        self.main = nn.Sequential(
            # (nz, 1, 1) -> (base_channels * 4, 4, 4)
            nn.utils.spectral_norm(nn.ConvTranspose2d(nz, base_channels * 4, 4)),
            nn.BatchNorm2d(base_channels * 4), nn.ReLU(),
            # (base_channels * 4, 4, 4) -> (base_channels * 2, 8, 8)
            nn.utils.spectral_norm(nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, 2, 1)),
            nn.BatchNorm2d(base_channels * 2), nn.ReLU(),
            SelfAttn(base_channels * 2),
            # (base_channels * 2, 8, 8) -> (base_channels, 16, 16)
            nn.utils.spectral_norm(nn.ConvTranspose2d(base_channels * 2, base_channels, 4, 2, 1)),
            nn.BatchNorm2d(base_channels), nn.ReLU(),
            SelfAttn(base_channels),
            # (base_channels, 16, 16) -> (MarioLevel.n_types, 16, 16)
            nn.utils.spectral_norm(nn.ConvTranspose2d(base_channels, MarioLevel.n_types, 3, 1, 1)),
            # Per-cell probability distribution over tile types.
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.main(x)

# SAGAN discriminator: mirrors the generator, downsampling a
# (MarioLevel.n_types, 16, 16) level patch to a single realness score per sample.
class SADiscriminator(nn.Module):
    def __init__(self, base_channels=32):
        super(SADiscriminator, self).__init__()
        self.main = nn.Sequential(
            # (MarioLevel.n_types, 16, 16) -> (base_channels, 16, 16)
            nn.utils.spectral_norm(nn.Conv2d(MarioLevel.n_types, base_channels, 3, 1, 1)),
            nn.BatchNorm2d(base_channels), nn.LeakyReLU(0.1),
            SelfAttn(base_channels),
            # (base_channels, 16, 16) -> (base_channels * 2, 8, 8)
            nn.utils.spectral_norm(nn.Conv2d(base_channels, base_channels * 2, 4, 2, 1)),
            nn.BatchNorm2d(base_channels * 2), nn.LeakyReLU(0.1),
            SelfAttn(base_channels * 2),
            # (base_channels * 2, 8, 8) -> (base_channels * 4, 4, 4)
            nn.utils.spectral_norm(nn.Conv2d(base_channels * 2, base_channels * 4, 4, 2, 1)),
            nn.BatchNorm2d(base_channels * 4), nn.LeakyReLU(0.1),
            # (base_channels * 4, 4, 4) -> (1, 1, 1), flattened to (N, 1)
            nn.utils.spectral_norm(nn.Conv2d(base_channels * 4, 1, 4)),
            nn.Flatten()
        )

    def forward(self, x):
        return self.main(x)

if __name__ == '__main__':
    # Smoke test: push uniform noise in [-1, 1) through both networks and
    # check the output shapes.
    noise = torch.rand(2, nz, 1, 1) * 2 - 1
    netG = SAGenerator()
    netD = SADiscriminator()
    # print(netG)
    X = netG(noise)
    Y = netD(X)
    print(X.shape, Y.shape)
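
    # --- Illustrative sketch only, not this repository's training code: a single
    # adversarial update with the hinge loss commonly paired with SAGAN-style
    # models. The "real" batch is a random stand-in, and the optimizer settings
    # are hypothetical assumptions.
    real = torch.softmax(torch.randn(2, MarioLevel.n_types, 16, 16), dim=1)
    optD = torch.optim.Adam(netD.parameters(), lr=4e-4, betas=(0.0, 0.9))
    optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.0, 0.9))

    # Discriminator step: push real scores above +1 and fake scores below -1.
    optD.zero_grad()
    loss_d = torch.relu(1. - netD(real)).mean() + torch.relu(1. + netD(netG(noise).detach())).mean()
    loss_d.backward()
    optD.step()

    # Generator step: raise the discriminator's score on generated levels.
    optG.zero_grad()
    loss_g = -netD(netG(noise)).mean()
    loss_g.backward()
    optG.step()
    print(f'loss_d={loss_d.item():.3f}  loss_g={loss_g.item():.3f}')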