# Wasserstein GAN (WGAN) image model: training loop and sampler.
from os import mkdir
from os.path import exists

import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from discriminator import Discriminator
from generator import Generator
from image_dataset import ImageDataset
class ImageWgan:
    """Wasserstein GAN (WGAN) for image generation.

    Wraps a project-defined Generator / Discriminator (critic) pair and
    implements the weight-clipping WGAN training recipe: RMSprop for both
    networks, critic updated every batch with weights clamped to
    [-clip_value, clip_value], generator updated every
    ``discriminator_steps`` batches.
    """

    def __init__(
        self,
        image_shape: (int, int, int),
        latent_space_dimension: int = 100,
        use_cuda: bool = False,
        generator_saved_model: str or None = None,
        discriminator_saved_model: str or None = None
    ):
        """Build (or reload) the generator and discriminator.

        image_shape: (channels, height, width) of the images — TODO confirm
            ordering against Generator/Discriminator.
        latent_space_dimension: size of the noise vector fed to the generator.
        use_cuda: move both networks to the GPU when True.
        generator_saved_model / discriminator_saved_model: optional paths to
            previously saved models, forwarded to the network constructors.
        """
        self.generator = Generator(image_shape, latent_space_dimension, use_cuda, generator_saved_model)
        self.discriminator = Discriminator(image_shape, use_cuda, discriminator_saved_model)
        self.image_shape = image_shape
        self.latent_space_dimension = latent_space_dimension
        self.use_cuda = use_cuda
        if use_cuda:
            self.generator.cuda()
            self.discriminator.cuda()

    def train(
        self,
        image_dataset: ImageDataset,
        learning_rate: float = 0.00005,
        batch_size: int = 64,
        workers: int = 8,
        epochs: int = 100,
        clip_value: float = 0.01,
        discriminator_steps: int = 5,
        sample_interval: int = 1000,
        sample_folder: str = 'samples',
        generator_save_file: str = 'generator.model',
        discriminator_save_file: str = 'discriminator.model'
    ):
        """Train both networks on ``image_dataset`` and save them at the end.

        Writes a 5x5 grid of generated images to ``sample_folder`` every
        ``sample_interval`` batches. Defaults (lr=5e-5, clip=0.01, 5 critic
        steps) match the original WGAN paper.
        """
        if not exists(sample_folder):
            mkdir(sample_folder)
        generator_optimizer = torch.optim.RMSprop(self.generator.parameters(), lr=learning_rate)
        discriminator_optimizer = torch.optim.RMSprop(self.discriminator.parameters(), lr=learning_rate)
        # Pick the tensor type once so noise/images land on the right device.
        Tensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
        # Use the directly-imported DataLoader (was torch.utils.data.DataLoader).
        data_loader = DataLoader(
            image_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=workers
        )
        batches_done = 0
        for epoch in range(epochs):
            for i, imgs in enumerate(data_loader):
                # Variable() wrappers removed: deprecated no-ops since torch 0.4.
                real_imgs = imgs.type(Tensor)
                discriminator_optimizer.zero_grad()
                # Sample noise as generator input
                z = Tensor(np.random.normal(0, 1, (imgs.shape[0], self.latent_space_dimension)))
                # detach(): the critic step must not backprop into the generator.
                fake_imgs = self.generator(z).detach()
                # Critic loss: maximize D(real) - D(fake)  (minimize the negation).
                discriminator_loss = -torch.mean(self.discriminator(real_imgs)) + torch.mean(self.discriminator(fake_imgs))
                discriminator_loss.backward()
                discriminator_optimizer.step()
                # Clip critic weights: crude enforcement of the Lipschitz constraint.
                for p in self.discriminator.parameters():
                    p.data.clamp_(-clip_value, clip_value)
                # Train the generator every discriminator_steps iterations.
                # This branch is taken at i == 0, so gen_imgs/generator_loss
                # are always bound before their first use below.
                if i % discriminator_steps == 0:
                    generator_optimizer.zero_grad()
                    # Generate a batch of images
                    gen_imgs = self.generator(z)
                    # Generator loss: maximize D(G(z)).
                    generator_loss = -torch.mean(self.discriminator(gen_imgs))
                    generator_loss.backward()
                    generator_optimizer.step()
                print(
                    f'[Epoch {epoch}/{epochs}] [Batch {batches_done % len(data_loader)}/{len(data_loader)}] ' +
                    f'[D loss: {discriminator_loss.item()}] [G loss: {generator_loss.item()}]'
                )
                if batches_done % sample_interval == 0:
                    save_image(gen_imgs.data[:25], f'{sample_folder}/{batches_done}.png', nrow=5, normalize=True)
                batches_done += 1
        self.discriminator.save(discriminator_save_file)
        self.generator.save(generator_save_file)

    def generate(
        self,
        sample_folder: str = 'samples',
        sample_count: int or None = None
    ):
        """Sample the generator and save a grid to ``{sample_folder}/generated.png``.

        sample_count: number of latent vectors to draw. Defaults to
            ``image_shape[0]`` to preserve the original behavior —
            NOTE(review): that is the channel count, which looks unintended
            (the grid saves up to 25 images); confirm the intended batch size.
        """
        if not exists(sample_folder):
            mkdir(sample_folder)
        Tensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
        if sample_count is None:
            sample_count = self.image_shape[0]
        # Inference only: no_grad() skips graph construction. The original
        # also ran generator_loss.backward() here, which built a graph and
        # accumulated gradients into both networks as a side effect — removed.
        with torch.no_grad():
            z = Tensor(np.random.normal(0, 1, (sample_count, self.latent_space_dimension)))
            gen_imgs = self.generator(z)
        save_image(gen_imgs.data[:25], f'{sample_folder}/generated.png', nrow=5, normalize=True)