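"""Gradio Space for training a pix2pix-style UNet.

Loads paired images from the K00B404/pix2pix_flux_set dataset (label 0 = original,
label 1 = target), trains a UNet to reconstruct the original from the target image,
and pushes the trained model to the K00B404/pix2pix_flux repository on the Hub.
"""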
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import Repository
import gradio as gr
from PIL import Image
import os
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
# Device configuration: the large 1024px UNet is only used when a CUDA GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device.type == 'cuda'
# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 16 if big else 1
EPOCHS = 12
LR = 0.0002
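# Note: lr=0.0002 with Adam mirrors the learning rate used in the original pix2pix setup.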
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"
# Create dataset and dataloader
class Pix2PixDataset(torch.utils.data.Dataset):
    def __init__(self, ds, transform):
        # Split the dataset into 'original' (label 0) and 'target' (label 1) images
        # https://huggingface.co/datasets/K00B404/pix2pix_flux_set/viewer/default/train?f[label][value]=0
        # https://huggingface.co/datasets/K00B404/pix2pix_flux_set/viewer/default/train?f[label][value]=1
        self.originals = [x for x in ds["train"] if x['label'] == 0]
        self.targets = [x for x in ds["train"] if x['label'] == 1]
        self.transform = transform

        # Ensure the number of original and target images match
        assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."

        # Debug: print dataset sizes
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        # Load the original and target images for the given index
        original_img = self.originals[idx]['image']
        target_img = self.targets[idx]['image']

        # The dataset may store images as file paths or as already-decoded PIL images
        original = Image.open(original_img).convert('RGB') if isinstance(original_img, str) else original_img.convert('RGB')
        target = Image.open(target_img).convert('RGB') if isinstance(target_img, str) else target_img.convert('RGB')

        # Return the transformed original and target images
        return self.transform(original), self.transform(target)
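# Pairing assumption: originals[idx] and targets[idx] are treated as a corresponding
# pair, so the label-0 and label-1 rows of the dataset must be stored in matching order.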
# Training function
def train_model(epochs):
    # Load the dataset
    ds = load_dataset(dataset_id)

    # Transforms: resize to the working resolution and convert to tensor
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    dataset = Pix2PixDataset(ds, transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize model, loss function, and optimizer;
    # fall back to a freshly initialized local UNet if pretrained weights cannot be loaded
    try:
        from diffusers import UNet2DModel  # assumed source of UNet2DModel; the except branch covers a failed import
        model = UNet2DModel.from_pretrained(model_repo_id).to(device)
    except Exception:
        model = big_UNet().to(device) if big else small_UNet().to(device)

    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    # Training loop
    for epoch in range(epochs):
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass: generate the cutout image from the target
            output = model(target)
            loss = criterion(output, original)  # compare with the original image

            # Backward pass
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}")

    # Return the trained model
    return model
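# Quick local sanity check (assumes the dataset and the local UNet modules are available):
#   model = train_model(epochs=1)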
# Push model to Hugging Face Hub
def push_model_to_hub(model, repo_name):
    # Push the trained weights to the given Hub repository
    model.push_to_hub(repo_name)
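# Note on push_to_hub: plain nn.Module has no push_to_hub method, so this step assumes
# the UNet classes in small_256_model / big_1024_model mix in
# huggingface_hub.PyTorchModelHubMixin (or otherwise provide push_to_hub themselves).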
# Gradio interface function
def gradio_train(epochs):
    model = train_model(int(epochs))
    push_model_to_hub(model, model_repo_id)
    return (f"Model trained for {epochs} epochs on the {dataset_id} dataset "
            f"and pushed to the {model_repo_id} repository on the Hugging Face Hub.")
# Gradio interface
gr_interface = gr.Interface(
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs"),
    outputs="text",
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository.",
)
if __name__ == '__main__':
    # Clone the model repository locally (or reuse an existing clone) and pull the latest changes
    repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
    repo.git_pull()

    # Launch the Gradio app
    gr_interface.launch()