Spaces:
Runtime error
Runtime error
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
class Generator(nn.Module):
    """Fully connected generator network that maps latent vectors to images.

    The network is a stack of ``Linear`` -> (``BatchNorm1d``) -> ``LeakyReLU``
    blocks followed by a final ``Linear`` + ``Tanh``, so every output value
    lies in ``[-1, 1]``. The flat output is reshaped to ``image_shape`` per
    sample in :meth:`forward`.
    """

    def __init__(
        self,
        image_shape: Tuple[int, int, int],
        latent_space_dimension: int,
        use_cuda: bool = False,
        saved_model: Optional[str] = None,
    ):
        """Build the network and optionally restore saved weights.

        Args:
            image_shape: Shape of one generated sample; the last linear layer
                produces ``prod(image_shape)`` values per sample.
            latent_space_dimension: Size of the input latent vector.
            use_cuda: Only selects the ``map_location`` used when reading
                ``saved_model`` ('cuda' if true, otherwise 'cpu'). This
                constructor does not move the module itself to a device.
            saved_model: Path to a state dict previously written by
                :meth:`save`. ``None`` keeps the freshly initialised weights.
        """
        super().__init__()
        self.image_shape = image_shape

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear/(BatchNorm)/LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the second positional argument of BatchNorm1d
                # is `eps`, not `momentum`. The value is kept as-is so that
                # existing checkpoints keep producing the same outputs;
                # confirm whether momentum=0.8 was actually intended.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # No normalisation on the first stage.
            *block(latent_space_dimension, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(image_shape))),
            nn.Tanh(),
        )

        if saved_model is not None:
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # that come from a trusted source.
            state_dict = torch.load(
                saved_model,
                map_location=torch.device('cuda' if use_cuda else 'cpu'),
            )
            self.model.load_state_dict(state_dict)

    def forward(self, z):
        """Generate a batch of images from latent vectors.

        Args:
            z: Tensor whose first dimension is the batch size and whose last
                dimension is ``latent_space_dimension``.

        Returns:
            Tensor of shape ``(z.shape[0], *image_shape)`` with values in
            ``[-1, 1]``.
        """
        img = self.model(z)
        # Unflatten each sample from prod(image_shape) values to image_shape.
        return img.view(img.shape[0], *self.image_shape)

    def save(self, to):
        """Write the state dict of the inner ``nn.Sequential`` to ``to``.

        The file stores ``self.model.state_dict()`` (not the state dict of the
        ``Generator`` wrapper), which is the format the ``saved_model``
        constructor argument expects.
        """
        torch.save(self.model.state_dict(), to)