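# app.py (from the Hugging Face Hub file view: author K00B404, commit d89262b
# "Create app.py", 4.82 kB).
# Gradio app that trains a Pix2Pix-style image-to-image model and pushes its
# weights to the Hugging Face Hub.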
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
import gradio as gr
from PIL import Image
import os
# Training hyper-parameters.
IMG_SIZE = 256   # images are resized to IMG_SIZE x IMG_SIZE before training
BATCH_SIZE = 1   # samples per optimisation step
EPOCHS = 12      # nominal epoch count; train_model receives its count from its caller
LR = 2e-4        # Adam learning rate

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Define the Pix2Pix model (a plain encoder-decoder despite the "UNet" name)
class UNet(nn.Module):
    """Convolutional encoder-decoder used as the Pix2Pix generator.

    Despite its name this network has no skip connections: the decoder only
    sees the 1024-channel bottleneck produced by the encoder.

    * Encoder: five ``Conv2d(kernel_size=4, stride=2, padding=1)`` stages,
      each followed by ``ReLU``. Channels go 3 -> 64 -> 128 -> 256 -> 512 ->
      1024 and every stage halves height and width (256 -> 8 for a 256x256
      input).
    * Decoder: five ``ConvTranspose2d`` stages with the same kernel, stride
      and padding. They mirror the channel widths back down to 3 and double
      height and width each time. The first four are followed by ``ReLU``; the
      last by ``Tanh``, so every output value lies in [-1, 1].

    Height and width should be multiples of 32 for the output to have the
    same spatial size as the input.
    """

    def __init__(self):
        super().__init__()
        conv_kw = dict(kernel_size=4, stride=2, padding=1)
        widths = (3, 64, 128, 256, 512, 1024)
        stages = list(zip(widths, widths[1:]))  # (in_channels, out_channels) per stage

        # Encoder: Conv + ReLU for every stage, walking the widths upwards.
        down_layers = []
        for in_ch, out_ch in stages:
            down_layers.append(nn.Conv2d(in_ch, out_ch, **conv_kw))
            down_layers.append(nn.ReLU(inplace=True))
        self.encoder = nn.Sequential(*down_layers)

        # Decoder: the same stages walked backwards with in/out swapped.
        # ReLU after every stage except the final one, which uses Tanh.
        mirrored = [(out_ch, in_ch) for in_ch, out_ch in reversed(stages)]
        final_stage = len(mirrored) - 1
        up_layers = []
        for k, (in_ch, out_ch) in enumerate(mirrored):
            up_layers.append(nn.ConvTranspose2d(in_ch, out_ch, **conv_kw))
            up_layers.append(nn.Tanh() if k == final_stage else nn.ReLU(inplace=True))
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode ``x`` (3 input channels) to the bottleneck and decode it back."""
        bottleneck = self.encoder(x)
        return self.decoder(bottleneck)
# Training function
def _load_pil(value):
    """Return ``value`` as an RGB PIL image.

    For columns with the ``Image`` feature, ``datasets`` yields already-decoded
    ``PIL.Image.Image`` objects. Those are used as-is; anything else (e.g. a
    file path) is passed to ``Image.open``.
    """
    image = value if isinstance(value, Image.Image) else Image.open(value)
    return image.convert("RGB")


class _Pix2PixDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(original, target)`` tensor pairs from one split."""

    def __init__(self, split, transform):
        self.split = split
        self.transform = transform

    def __len__(self):
        return len(self.split)

    def __getitem__(self, idx):
        # Fetch (and decode) the row once instead of once per column.
        row = self.split[idx]
        original = _load_pil(row["original_image"])
        target = _load_pil(row["target_image"])
        return self.transform(original), self.transform(target)


def train_model(epochs):
    """Train a fresh ``UNet`` on the ``K00B404/pix2pix_flux_set`` train split.

    Each image is resized to ``IMG_SIZE`` x ``IMG_SIZE`` and converted to a
    tensor with ``ToTensor``. The model is optimised with Adam (learning rate
    ``LR``) on an L1 loss, with batches of ``BATCH_SIZE`` drawn in shuffled
    order. Progress is printed every 100 steps.

    Args:
        epochs: number of full passes over the training split.

    Returns:
        The trained ``UNet``, still on ``device``.
    """
    ds = load_dataset("K00B404/pix2pix_flux_set")
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    dataset = _Pix2PixDataset(ds["train"], transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = UNet().to(device)
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)
    # NOTE(review): ToTensor produces values in [0, 1] while the model ends in
    # Tanh ([-1, 1]); conventional Pix2Pix normalises inputs to [-1, 1].

    steps_per_epoch = len(dataloader)
    for epoch in range(epochs):
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()
            # NOTE(review): the model is fed the *target* image and scored
            # against the *original* one, i.e. it learns target -> original.
            # The column names suggest the opposite direction; confirm this is
            # intended before relying on the trained weights.
            output = model(target)
            loss = criterion(output, original)
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                # epoch is 0-based internally; report it 1-based against the total.
                print(f"Epoch [{epoch + 1}/{epochs}], Step [{i}/{steps_per_epoch}], Loss: {loss.item():.4f}")
    return model
# Push model to Hugging Face Hub
def push_model_to_hub(model, repo_name):
    """Upload ``model``'s weights to the Hub model repo ``repo_name``.

    The repo is created if it does not exist yet. The ``state_dict`` (moved to
    CPU so it loads on machines without CUDA) is written with ``torch.save`` to
    a temporary file. It is then uploaded as ``pix2pix_model.pth`` in a single
    commit over the Hub HTTP API, so no local git clone or git-lfs is needed.
    The caller must be authenticated with write access to ``repo_name``.

    Args:
        model: the trained ``nn.Module`` whose ``state_dict`` is uploaded.
        repo_name: Hub repo id, e.g. ``"user/model-name"``.
    """
    # Local imports keep this fix self-contained. The previous git-based
    # ``Repository(repo_name)`` call raised because no ``clone_from`` was
    # given, and that class is deprecated upstream in favour of ``HfApi``.
    import tempfile
    from huggingface_hub import HfApi

    api = HfApi()
    api.create_repo(repo_id=repo_name, exist_ok=True)
    cpu_state = {name: tensor.detach().cpu() for name, tensor in model.state_dict().items()}
    with tempfile.TemporaryDirectory() as tmp_dir:
        weights_path = os.path.join(tmp_dir, "pix2pix_model.pth")
        torch.save(cpu_state, weights_path)
        api.upload_file(
            path_or_fileobj=weights_path,
            path_in_repo="pix2pix_model.pth",
            repo_id=repo_name,
            commit_message="Initial commit with trained Pix2Pix model.",
        )
# Gradio interface function
def gradio_train(epochs, repo_name="K00B404/pix2pix_flux"):
    """Gradio callback: train the model for ``epochs`` epochs and push it to the Hub.

    Args:
        epochs: value from the ``gr.Number`` input. It may be ``None`` (empty
            field), a float, NaN or infinity, so it is converted with ``int``
            (truncating) and must be at least 1.
        repo_name: Hub repo id the trained weights are uploaded to.

    Returns:
        A status message for the text output. Invalid input returns a prompt
        instead of training, so an untrained model is never pushed.
    """
    try:
        n_epochs = int(epochs)
    except (TypeError, ValueError, OverflowError):
        # None, NaN and infinity cannot become a valid epoch count.
        n_epochs = 0
    if n_epochs < 1:
        return "Please enter a number of epochs (at least 1)."
    model = train_model(n_epochs)
    push_model_to_hub(model, repo_name)
    return f"Model trained for {n_epochs} epochs and pushed to Hugging Face Hub repository '{repo_name}'."
# Gradio Interface: one epoch-count input, one text status output.
gr_interface = gr.Interface(
    fn=gradio_train,
    # precision=0 makes Gradio round the value and pass an int to the callback;
    # the field is pre-filled with the module's EPOCHS default.
    inputs=gr.Number(label="Number of Epochs", value=EPOCHS, precision=0),
    outputs="text",
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository.",
)
if __name__ == "__main__":
    # Ensure the model repo exists on the Hub (exist_ok makes this a no-op
    # when it already does), then start the Gradio server.
    create_repo(repo_id="K00B404/pix2pix_flux", exist_ok=True)
    gr_interface.launch()