baiyanlali-zhao's picture
添加注释
3582c8a
import torch
from torch import nn
from src.smb.level import MarioLevel
from src.utils.dl import SelfAttn
# Size of the latent noise vector fed to the generator; it is used as the
# input-channel count of the generator's first transposed convolution, and the
# demo below samples noise of shape (batch, nz, 1, 1).
nz: int = 20

# ---- Self-Attention GAN (SAGAN) models ----
class SAGenerator(nn.Module):
    """Self-Attention GAN generator.

    Maps a latent tensor of shape ``(B, latent_dim, 1, 1)`` to a tensor of
    shape ``(B, out_channels, 16, 16)``. A softmax is applied over the channel
    dimension, so at every spatial position the output channels sum to 1.

    Spatial size progression for a 1x1 input (from the transposed-convolution
    kernel/stride/padding arguments): 1 -> 4 -> 8 -> 16 -> 16.

    Every transposed convolution is wrapped in spectral normalization and, except
    for the last, followed by BatchNorm + ReLU. ``SelfAttn`` blocks (defined in
    ``src.utils.dl``) are inserted at the 8x8 and 16x16 resolutions; they are
    assumed to preserve the tensor shape — TODO confirm against SelfAttn.

    Args:
        base_channels: Width multiplier. The hidden layers use
            ``4 * base_channels``, ``2 * base_channels`` and ``base_channels``
            channels respectively.
        latent_dim: Number of input channels (latent size). ``None`` (the
            default) uses the module-level ``nz``.
        out_channels: Number of output channels. ``None`` (the default) uses
            ``MarioLevel.n_types``.
    """

    def __init__(self, base_channels=32, latent_dim=None, out_channels=None):
        super().__init__()
        # Resolve defaults at construction time, exactly like the original
        # hard-coded references to `nz` and `MarioLevel.n_types` did.
        if latent_dim is None:
            latent_dim = nz
        if out_channels is None:
            out_channels = MarioLevel.n_types
        # NOTE: layer order/indices are kept stable so existing state_dicts
        # (keys like "main.0.weight_orig") still load.
        self.main = nn.Sequential(
            # 1x1 -> 4x4
            nn.utils.spectral_norm(nn.ConvTranspose2d(latent_dim, base_channels * 4, 4)),
            nn.BatchNorm2d(base_channels * 4), nn.ReLU(),
            # 4x4 -> 8x8
            nn.utils.spectral_norm(nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, 2, 1)),
            nn.BatchNorm2d(base_channels * 2), nn.ReLU(),
            SelfAttn(base_channels * 2),
            # 8x8 -> 16x16
            nn.utils.spectral_norm(nn.ConvTranspose2d(base_channels * 2, base_channels, 4, 2, 1)),
            nn.BatchNorm2d(base_channels), nn.ReLU(),
            SelfAttn(base_channels),
            # 16x16 -> 16x16, projecting to the output channels
            nn.utils.spectral_norm(nn.ConvTranspose2d(base_channels, out_channels, 3, 1, 1)),
            # Per-position distribution over the output channels.
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Generate an output from latent tensor ``x`` of shape (B, latent_dim, 1, 1)."""
        return self.main(x)
class SADiscriminator(nn.Module):
    """Self-Attention GAN discriminator.

    Scores a tensor of shape ``(B, in_channels, H, W)``. For a 16x16 input
    (the generator's output size) the spatial size progresses
    16 -> 16 -> 8 -> 4 -> 1 (from the convolution kernel/stride/padding
    arguments), and after ``nn.Flatten`` the output has shape ``(B, 1)``.
    No sigmoid is applied, so the score is unbounded.

    Every convolution is wrapped in spectral normalization and, except for the
    last, followed by BatchNorm + LeakyReLU(0.1). ``SelfAttn`` blocks (defined
    in ``src.utils.dl``) are inserted at the 16x16 and 8x8 resolutions; they
    are assumed to preserve the tensor shape — TODO confirm against SelfAttn.

    Args:
        base_channels: Width multiplier. The hidden layers use
            ``base_channels``, ``2 * base_channels`` and ``4 * base_channels``
            channels respectively.
        in_channels: Number of input channels. ``None`` (the default) uses
            ``MarioLevel.n_types``.
    """

    def __init__(self, base_channels=32, in_channels=None):
        super().__init__()
        # Resolve the default at construction time, exactly like the original
        # hard-coded reference to `MarioLevel.n_types` did.
        if in_channels is None:
            in_channels = MarioLevel.n_types
        # NOTE: layer order/indices are kept stable so existing state_dicts
        # still load.
        self.main = nn.Sequential(
            # H x W -> H x W
            nn.utils.spectral_norm(nn.Conv2d(in_channels, base_channels, 3, 1, 1)),
            nn.BatchNorm2d(base_channels), nn.LeakyReLU(0.1),
            SelfAttn(base_channels),
            # halve spatial size (16 -> 8)
            nn.utils.spectral_norm(nn.Conv2d(base_channels, base_channels * 2, 4, 2, 1)),
            nn.BatchNorm2d(base_channels * 2), nn.LeakyReLU(0.1),
            SelfAttn(base_channels * 2),
            # halve spatial size (8 -> 4)
            nn.utils.spectral_norm(nn.Conv2d(base_channels * 2, base_channels * 4, 4, 2, 1)),
            nn.BatchNorm2d(base_channels * 4), nn.LeakyReLU(0.1),
            # 4x4 -> 1x1 single-channel score
            nn.utils.spectral_norm(nn.Conv2d(base_channels * 4, 1, 4)),
            nn.Flatten(),
        )

    def forward(self, x):
        """Return the discriminator score for ``x``; shape (B, 1) for 16x16 input."""
        return self.main(x)
if __name__ == '__main__':
    # Smoke test: push a batch of 2 uniform-noise latents in [-1, 1) through the
    # generator and then the discriminator, and print the resulting shapes.
    # For the default models this should be (2, MarioLevel.n_types, 16, 16)
    # and (2, 1), assuming SelfAttn preserves shapes.
    noise = torch.rand(2, nz, 1, 1) * 2 - 1
    netG = SAGenerator()
    netD = SADiscriminator()
    # Only shapes are inspected, so there is no need to build an autograd graph.
    with torch.no_grad():
        X = netG(noise)
        Y = netD(X)
    print(X.shape, Y.shape)