# Hugging Face Spaces status: Running
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Define the Generator
class Generator(nn.Module):
    """Fully connected GAN generator: latent vector -> flattened image.

    Architecture: ``Linear -> ReLU`` for each entry of ``hidden_dims``, then a
    final ``Linear -> Tanh``. The Tanh bounds every output value to [-1, 1],
    the same range the real images are normalised to in this script.

    Args:
        input_dim: Size of the latent (noise) vector.
        output_dim: Number of values per generated sample (784 for a
            flattened 28x28 MNIST image in this script).
        hidden_dims: Widths of the hidden layers. The default ``(128, 256)``
            reproduces the original architecture exactly, including the
            ``state_dict`` keys (``model.0``, ``model.2``, ``model.4``), so
            previously saved weights still load.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(128, 256)):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        layers += [nn.Linear(in_features, output_dim), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map latent vectors of shape ``(..., input_dim)`` to ``(..., output_dim)``."""
        return self.model(x)
# Define the Discriminator
class Discriminator(nn.Module):
    """Fully connected GAN discriminator: flattened image -> probability of "real".

    Architecture: ``Linear -> LeakyReLU`` for each entry of ``hidden_dims``,
    then ``Linear -> Sigmoid`` producing one value in (0, 1) per sample. The
    sigmoid output is what ``nn.BCELoss`` in this script expects; switching to
    ``nn.BCEWithLogitsLoss`` would require dropping the final Sigmoid.

    Args:
        input_dim: Number of values per input sample (784 for a flattened
            28x28 MNIST image in this script).
        hidden_dims: Widths of the hidden layers. The default ``(256, 128)``
            reproduces the original architecture and ``state_dict`` keys.
        negative_slope: Slope of LeakyReLU for negative inputs (originally
            hard-coded as 0.2).
    """

    def __init__(self, input_dim, hidden_dims=(256, 128), negative_slope=0.2):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_features, width), nn.LeakyReLU(negative_slope)]
            in_features = width
        layers += [nn.Linear(in_features, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score inputs of shape ``(..., input_dim)``; returns ``(..., 1)`` probabilities."""
        return self.model(x)
# Hyperparameters
latent_dim = 100
image_dim = 28 * 28  # MNIST images are 28x28 pixels
lr = 0.0002
batch_size = 64

# Prepare the data.
# ToTensor() scales pixels to [0, 1]; Normalize([0.5], [0.5]) then maps them to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


@st.cache_resource
def _load_mnist():
    """Return the MNIST training split, downloading it on first use.

    Streamlit re-executes this whole script on every widget interaction (e.g.
    each slider move). Caching the dataset means it is loaded from disk once
    per server process instead of on every rerun. The dataset object is only
    iterated by the DataLoader below, so sharing one instance is safe.
    """
    return datasets.MNIST(root='mnist_data', train=True, transform=transform, download=True)


dataset = _load_mnist()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the models. These are deliberately NOT cached: each script run
# starts from freshly initialised weights, as in the original.
generator = Generator(latent_dim, image_dim)
discriminator = Discriminator(image_dim)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

# Loss function: binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()
# Streamlit interface
st.title("GAN with PyTorch and Hugging Face")
st.write("Training a GAN to generate MNIST digits")

# Slider for epochs
epochs = st.slider("Number of Epochs", min_value=1, max_value=100, value=50)
train_gan = st.button("Train GAN")

if train_gan:
    # Training loop. The models are re-created on every script run, so
    # training and sampling must both happen inside this same run.
    for epoch in range(epochs):
        for imgs, _ in dataloader:
            # The last batch may be smaller than batch_size, so size every
            # tensor from the actual batch.
            n = imgs.size(0)
            real_imgs = imgs.view(n, -1)  # flatten each image to image_dim values
            real_labels = torch.ones(n, 1)
            fake_labels = torch.zeros(n, 1)

            z = torch.randn(n, latent_dim)
            fake_imgs = generator(z)

            # Train Discriminator: push D(real) towards 1 and D(fake) towards 0.
            # detach() keeps this step from backpropagating into the generator.
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_imgs), real_labels)
            fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_D.step()

            # Train Generator: reward fakes that the (just updated)
            # discriminator classifies as real.
            optimizer_G.zero_grad()
            g_loss = criterion(discriminator(fake_imgs), real_labels)
            g_loss.backward()
            optimizer_G.step()

        # Losses reported are those of the last batch of the epoch.
        st.write(f"Epoch [{epoch+1}/{epochs}] | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")
    st.write("Training completed")

    # Generate and display a 4x4 grid of samples. No gradients are needed
    # for inference, so skip building the autograd graph entirely.
    with torch.no_grad():
        z = torch.randn(16, latent_dim)
        generated_imgs = generator(z).view(-1, 1, 28, 28).cpu().numpy()
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for img, ax in zip(generated_imgs, axes.flatten()):
        ax.imshow(img.reshape(28, 28), cmap="gray")
        ax.axis('off')
    st.pyplot(fig)
    # st.pyplot has already rendered the figure; close it so a long-running
    # Streamlit server does not accumulate open matplotlib figures.
    plt.close(fig)
else:
    st.write("Use the slider to select the number of epochs and click the button to start training the GAN")