# Train a simplified Pix2Pix-style UNet on an image-pair dataset and push it to the Hugging Face Hub via a Gradio app
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
import gradio as gr
from PIL import Image
import os

# Parameters
IMG_SIZE = 256
BATCH_SIZE = 1
EPOCHS = 12
LR = 0.0002

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Define the Pix2Pix model (Simplified UNet)
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),     # 256 -> 128
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),   # 128 -> 64
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 64 -> 32
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),  # 32 -> 16
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1), # 16 -> 8
            nn.ReLU(inplace=True)
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1), # 8 -> 16
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # 16 -> 32
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 32 -> 64
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 64 -> 128
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),     # 128 -> 256
            nn.Tanh()  # Output range [-1, 1]
        )

    def forward(self, x):
        enc = self.encoder(x)
        dec = self.decoder(enc)
        return dec
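
# Optional sanity check (illustrative helper, not called by the app): a random
# 1x3x256x256 batch should pass through the encoder/decoder and come back at
# the same spatial size, confirming the five stride-2 stages line up.
def check_unet_shapes():
    dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)
    out = UNet()(dummy)
    assert out.shape == dummy.shape, f"unexpected output shape: {out.shape}"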
# Training function
def train_model(epochs):
    # Load the dataset
    ds = load_dataset("K00B404/pix2pix_flux_set")

    # Transform function to resize, convert to tensor, and scale to [-1, 1]
    # (matching the Tanh output range of the generator)
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Create dataset and dataloader
    class Pix2PixDataset(torch.utils.data.Dataset):
        def __init__(self, ds):
            self.ds = ds

        def __len__(self):
            return len(self.ds["train"])

        def __getitem__(self, idx):
            # The dataset columns may hold file paths or already-decoded PIL
            # images; handle both cases.
            def to_pil(item):
                return item if isinstance(item, Image.Image) else Image.open(item)

            original = to_pil(self.ds["train"][idx]['original_image']).convert('RGB')
            target = to_pil(self.ds["train"][idx]['target_image']).convert('RGB')
            return transform(original), transform(target)

    dataset = Pix2PixDataset(ds)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    # Initialize model, loss function, and optimizer
    # (simplified objective: L1 reconstruction loss only, no adversarial term)
    model = UNet().to(device)
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    # Training loop
    for epoch in range(epochs):
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass: predict the target image from the original
            # (swap the two if the dataset defines the pairing the other way)
            output = model(original)
            loss = criterion(output, target)

            # Backward pass
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}")

    # Return trained model
    return model
# Push model to Hugging Face Hub
def push_model_to_hub(model, repo_name):
    # Clone (or reuse) a local working copy of the Hub repo
    local_dir = repo_name.split("/")[-1]
    repo = Repository(local_dir=local_dir, clone_from=repo_name)

    # Save the model state dict into the working copy
    model_save_path = os.path.join(local_dir, "pix2pix_model.pth")
    torch.save(model.state_dict(), model_save_path)

    # Commit and push the weights to the Hub
    repo.push_to_hub(commit_message="Initial commit with trained Pix2Pix model.")
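
# Illustrative inference helper (an assumption, not wired into the Gradio app):
# reload saved weights and translate a single PIL image. The preprocessing
# mirrors the training transform above, and the Tanh output in [-1, 1] is
# mapped back to [0, 1] for display.
def run_inference(model_path, image):
    model = UNet().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    preprocess = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    with torch.no_grad():
        output = model(preprocess(image).unsqueeze(0).to(device))
    output = (output.squeeze(0).cpu() * 0.5 + 0.5).clamp(0, 1)
    return transforms.ToPILImage()(output)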
# Gradio interface function
def gradio_train(epochs):
    model = train_model(int(epochs))
    push_model_to_hub(model, "K00B404/pix2pix_flux")
    return f"Model trained for {epochs} epochs and pushed to Hugging Face Hub repository 'K00B404/pix2pix_flux'."
# Gradio Interface
gr_interface = gr.Interface(
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs", value=EPOCHS),
    outputs="text",
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
)
if __name__ == '__main__':
    # Create the repository on the Hub if it does not already exist
    create_repo("K00B404/pix2pix_flux", exist_ok=True)

    # Launch the Gradio app
    gr_interface.launch()