# Minecraft_Skin_Generator / image_wgan.py
# meeww's picture
# Update image_wgan.py
# 2825685
from os import makedirs
from os import mkdir
from os.path import exists
from typing import Optional
from typing import Tuple

import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from discriminator import Discriminator
from generator import Generator
from image_dataset import ImageDataset
class ImageWgan:
    """Wasserstein GAN wrapper that pairs a ``Generator`` with a ``Discriminator``.

    The class owns both networks, trains them with RMSprop and critic weight
    clipping, and can write a grid of generated images to disk.
    """

    def __init__(
            self,
            image_shape: Tuple[int, int, int],
            latent_space_dimension: int = 100,
            use_cuda: bool = False,
            generator_saved_model: Optional[str] = None,
            discriminator_saved_model: Optional[str] = None
    ):
        """Build both networks and optionally move them to the GPU.

        Args:
            image_shape: Three-integer image shape, forwarded unchanged to the
                ``Generator`` and ``Discriminator`` constructors.
                NOTE(review): presumably (channels, height, width) - confirm
                against those classes.
            latent_space_dimension: Length of each noise vector fed to the
                generator.
            use_cuda: When true, both networks are moved to CUDA and all
                tensors created by this class are allocated there.
            generator_saved_model: Forwarded to ``Generator``; presumably a
                path to previously saved weights, or ``None``.
            discriminator_saved_model: Forwarded to ``Discriminator``;
                presumably a path to previously saved weights, or ``None``.
        """
        self.generator = Generator(image_shape, latent_space_dimension, use_cuda, generator_saved_model)
        self.discriminator = Discriminator(image_shape, use_cuda, discriminator_saved_model)
        self.image_shape = image_shape
        self.latent_space_dimension = latent_space_dimension
        self.use_cuda = use_cuda

        if use_cuda:
            self.generator.cuda()
            self.discriminator.cuda()

    def _device(self) -> torch.device:
        """Return the device on which this instance allocates its tensors."""
        return torch.device('cuda' if self.use_cuda else 'cpu')

    def _sample_latent(self, count: int) -> torch.Tensor:
        """Draw ``count`` standard-normal noise vectors as a float32 tensor.

        NumPy's global RNG is used on purpose so that ``np.random.seed``
        controls the noise.
        """
        noise = np.random.normal(0, 1, (count, self.latent_space_dimension))
        return torch.as_tensor(noise, dtype=torch.float32, device=self._device())

    def _generate_images(self, count: int) -> torch.Tensor:
        """Run the generator on ``count`` fresh noise vectors without tracking gradients."""
        with torch.no_grad():
            return self.generator(self._sample_latent(count))

    def train(
            self,
            image_dataset: 'ImageDataset',
            learning_rate: float = 0.00005,
            batch_size: int = 64,
            workers: int = 8,
            epochs: int = 100,
            clip_value: float = 0.01,
            discriminator_steps: int = 5,
            sample_interval: int = 1000,
            sample_folder: str = 'samples',
            generator_save_file: str = 'generator.model',
            discriminator_save_file: str = 'discriminator.model'
    ):
        """Train the generator and the discriminator on ``image_dataset``.

        The discriminator is updated on every batch and its weights are then
        clamped to ``[-clip_value, clip_value]``. The generator is updated on
        every ``discriminator_steps``-th batch of an epoch.

        Args:
            image_dataset: Dataset whose batches are image tensors.
            learning_rate: Learning rate of both RMSprop optimizers.
            batch_size: Number of images per batch.
            workers: Number of ``DataLoader`` worker processes.
            epochs: Number of passes over the dataset.
            clip_value: Bound applied to every discriminator parameter after
                each discriminator step.
            discriminator_steps: The generator is trained once per this many
                batches.
            sample_interval: A sample grid is written every this many batches
                (counted across epochs).
            sample_folder: Folder receiving the sample grids; created if
                missing, including parent folders.
            generator_save_file: Destination passed to ``Generator.save``.
            discriminator_save_file: Destination passed to
                ``Discriminator.save``.
        """
        makedirs(sample_folder, exist_ok=True)

        generator_optimizer = torch.optim.RMSprop(self.generator.parameters(), lr=learning_rate)
        discriminator_optimizer = torch.optim.RMSprop(self.discriminator.parameters(), lr=learning_rate)

        device = self._device()
        data_loader = DataLoader(
            image_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=workers
        )
        batch_count = len(data_loader)

        batches_done = 0
        for epoch in range(epochs):
            for i, imgs in enumerate(data_loader):
                real_imgs = imgs.to(device=device, dtype=torch.float32)

                discriminator_optimizer.zero_grad()

                # Sample noise as generator input
                z = self._sample_latent(imgs.shape[0])

                # The generator must not receive gradients from the
                # discriminator step, so its forward pass is not tracked.
                with torch.no_grad():
                    fake_imgs = self.generator(z)

                # Adversarial loss
                discriminator_loss = (
                    -torch.mean(self.discriminator(real_imgs))
                    + torch.mean(self.discriminator(fake_imgs))
                )
                discriminator_loss.backward()
                discriminator_optimizer.step()

                # Clip weights of discriminator
                with torch.no_grad():
                    for p in self.discriminator.parameters():
                        p.clamp_(-clip_value, clip_value)

                # Train the generator every discriminator_steps iterations.
                # i == 0 always qualifies, so gen_imgs and generator_loss
                # exist before they are first read below.
                if i % discriminator_steps == 0:
                    generator_optimizer.zero_grad()

                    # Generate a batch of images
                    gen_imgs = self.generator(z)

                    # Adversarial loss
                    generator_loss = -torch.mean(self.discriminator(gen_imgs))
                    generator_loss.backward()
                    generator_optimizer.step()

                    print(
                        f'[Epoch {epoch}/{epochs}] [Batch {batches_done % batch_count}/{batch_count}] ' +
                        f'[D loss: {discriminator_loss.item()}] [G loss: {generator_loss.item()}]'
                    )

                if batches_done % sample_interval == 0:
                    # Uses the images from the most recent generator step.
                    save_image(
                        gen_imgs.detach()[:25],
                        f'{sample_folder}/{batches_done}.png',
                        nrow=5,
                        normalize=True
                    )
                batches_done += 1

            # Persist both networks once per epoch so an interrupted run
            # keeps its latest completed epoch.
            self.discriminator.save(discriminator_save_file)
            self.generator.save(generator_save_file)

    def generate(
            self,
            sample_folder: str = 'samples',
            sample_count: Optional[int] = None
    ):
        """Generate images and save them as ``<sample_folder>/generated.png``.

        No gradients are computed and the discriminator is not involved:
        generation only needs a forward pass through the generator.

        Args:
            sample_folder: Output folder; created if missing, including
                parent folders.
            sample_count: Number of images to generate. ``None`` keeps the
                historical default of ``image_shape[0]``. At most the first
                25 images are written, laid out five per row.
        """
        makedirs(sample_folder, exist_ok=True)

        if sample_count is None:
            sample_count = self.image_shape[0]

        gen_imgs = self._generate_images(sample_count)
        save_image(gen_imgs[:25], f'{sample_folder}/generated.png', nrow=5, normalize=True)