# pytorch / pages/24_GANs.py
# Last commit 0c517c9 (verified) by eaglelandsonce:
#   "Rename pages/23_GANs.py to pages/24_GANs.py"
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
# Define the Generator
class Generator(nn.Module):
    """Fully connected generator network.

    Architecture: ``input_dim -> 128 -> 256 -> output_dim``, with ReLU after
    each hidden layer and Tanh on the output, so every output value lies in
    [-1, 1]. In this app it maps ``latent_dim`` noise vectors to flattened
    ``image_dim`` images.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_widths = (128, 256)
        layers = []
        fan_in = input_dim
        for width in hidden_widths:
            layers += [nn.Linear(fan_in, width), nn.ReLU()]
            fan_in = width
        # Tanh bounds the output to [-1, 1].
        layers += [nn.Linear(fan_in, output_dim), nn.Tanh()]
        # Keep the attribute name and layer order so state_dict keys
        # (model.0.*, model.2.*, model.4.*) stay the same.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension must equal ``input_dim``)."""
        out = self.model(x)
        return out
# Define the Discriminator
class Discriminator(nn.Module):
    """Fully connected discriminator network.

    Architecture: ``input_dim -> 256 -> 128 -> 1``, with LeakyReLU(0.2) after
    each hidden layer and a Sigmoid on the single output unit, so each row
    of the input gets one score in [0, 1].
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = (input_dim, 256, 128)
        stack = []
        # Pair consecutive widths: (input_dim, 256), (256, 128).
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)))
        stack.extend((nn.Linear(widths[-1], 1), nn.Sigmoid()))
        # Same attribute name and layer order as before, so checkpoints load.
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return one sigmoid score per row of ``x``."""
        scores = self.model(x)
        return scores
# Hyperparameters
latent_dim = 100       # length of the noise vector fed to the generator
image_dim = 28 * 28    # MNIST images are 28x28 pixels, flattened to a vector
lr = 0.0002            # learning rate for both optimizers
batch_size = 64

# Prepare the data.
# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5. Applied to
# ToTensor's [0, 1] output, this gives [-1, 1], the same range as the
# generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)
dataset = datasets.MNIST(
    root="mnist_data",
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Initialize the models: the generator maps latent_dim -> image_dim and the
# discriminator scores image_dim-sized inputs.
generator = Generator(latent_dim, image_dim)
discriminator = Discriminator(image_dim)

# One Adam optimizer per network, both at the same learning rate
# (created generator-first, as before).
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=lr) for net in (generator, discriminator)
)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()
# Streamlit interface and training entry point.


def _train_one_epoch():
    """Make one pass over ``dataloader``, alternating D and G updates.

    Uses the module-level models, optimizers and ``criterion`` defined above.

    Returns:
        ``(d_loss, g_loss)`` from the last batch as Python floats, or
        ``(nan, nan)`` if the loader produced no batches.
    """
    d_loss = g_loss = None
    for imgs, _ in dataloader:
        n = imgs.size(0)  # the final batch may be smaller than batch_size
        real_imgs = imgs.view(n, -1)  # flatten each image to image_dim values
        real_labels = torch.ones(n, 1)
        fake_labels = torch.zeros(n, 1)
        z = torch.randn(n, latent_dim)
        fake_imgs = generator(z)

        # Train Discriminator. detach() keeps D's loss from back-propagating
        # into the generator.
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), real_labels)
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Train Generator. Fakes are labelled "real" so G is pushed to fool D.
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

    if d_loss is None:
        return float("nan"), float("nan")
    return d_loss.item(), g_loss.item()


def _sample_figure(num_images=16, rows=4, cols=4):
    """Draw ``num_images`` generator samples on a ``rows`` x ``cols`` grid.

    Returns:
        The matplotlib Figure. The caller must close it after rendering.
    """
    # Inference only: no autograd graph is needed for display.
    with torch.no_grad():
        z = torch.randn(num_images, latent_dim)
        samples = generator(z).view(-1, 28, 28).cpu().numpy()
    fig, axes = plt.subplots(rows, cols, figsize=(8, 8))
    for img, ax in zip(samples, axes.flatten()):
        ax.imshow(img, cmap="gray")
        ax.axis('off')
    return fig


st.title("GAN with PyTorch and Hugging Face")
st.write("Training a GAN to generate MNIST digits")
# Slider for epochs
epochs = st.slider("Number of Epochs", min_value=1, max_value=100, value=50)
train_gan = st.button("Train GAN")
if train_gan:
    # Training loop
    for epoch in range(epochs):
        d_loss_value, g_loss_value = _train_one_epoch()
        st.write(f"Epoch [{epoch+1}/{epochs}] | D Loss: {d_loss_value:.4f} | G Loss: {g_loss_value:.4f}")
    st.write("Training completed")
    # Generate and display images
    fig = _sample_figure()
    st.pyplot(fig)
    # pyplot keeps every figure open until it is closed explicitly. Without
    # this, each training run would leak one figure in the long-lived
    # Streamlit process.
    plt.close(fig)
else:
    st.write("Use the slider to select the number of epochs and click the button to start training the GAN")