Spaces:
Running
Running
File size: 4,146 Bytes
d89262b 3010c48 d89262b a6be944 d89262b 3010c48 d89262b cc04230 d626bab cc04230 d626bab cc04230 d626bab cc04230 d626bab d89262b d626bab d89262b d626bab 5010115 d89262b 3010c48 d89262b 3010c48 d89262b 3010c48 d89262b 3010c48 d89262b 3010c48 d89262b 3010c48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import Repository
import gradio as gr
from PIL import Image
import os
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
# Device configuration: use the GPU whenever PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The large (1024px) model is only selected when not running on the CPU.
big = device.type != 'cpu'
# Parameters: image size and batch size scale with the selected model.
if big:
    IMG_SIZE, BATCH_SIZE = 1024, 16
else:
    IMG_SIZE, BATCH_SIZE = 256, 1
EPOCHS = 12
LR = 2e-4
# Hugging Face Hub identifiers for the training data and the model repo.
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"
# Create dataset and dataloader
class Pix2PixDataset(torch.utils.data.Dataset):
    """Paired-image dataset built from a Hugging Face ``DatasetDict``.

    Rows of ``ds["train"]`` whose ``'label'`` is ``'original'`` or
    ``'target'`` are collected into two lists and paired by position: the
    i-th original is paired with the i-th target. Rows with any other label
    are ignored.

    Args:
        ds: mapping with a ``"train"`` split whose rows are dicts holding
            ``'label'`` and ``'image'`` keys.
        transform: callable applied to each RGB PIL image. When omitted, a
            default that resizes to ``IMG_SIZE`` x ``IMG_SIZE`` and converts
            to a tensor is used.

    Raises:
        ValueError: if the numbers of original and target rows differ.
    """

    def __init__(self, ds, transform=None):
        # Split the rows in a single pass (the split was previously
        # iterated twice, decoding every row twice).
        self.originals = []
        self.targets = []
        for row in ds["train"]:
            if row['label'] == 'original':
                self.originals.append(row)
            elif row['label'] == 'target':
                self.targets.append(row)
        # Pairing is positional, so both lists must have the same length.
        # Raise rather than assert: asserts vanish under ``python -O``.
        if len(self.originals) != len(self.targets):
            raise ValueError("Mismatch in number of original and target images.")
        # The transform used to be looked up as a module-level global that
        # was never defined (it only existed as a local in train_model), so
        # __getitem__ raised NameError. Keep it on the instance instead.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
            ])
        self.transform = transform
        # Debug: Print dataset size
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")

    def __len__(self):
        """Return the number of (original, target) pairs."""
        return len(self.originals)

    @staticmethod
    def _to_rgb(img):
        """Return *img* as an RGB PIL image.

        The ``'image'`` value may already be a decoded PIL image (as the
        ``datasets`` Image feature usually yields), or a path / file object
        that still needs opening.
        """
        if isinstance(img, Image.Image):
            return img.convert('RGB')
        return Image.open(img).convert('RGB')

    def __getitem__(self, idx):
        """Return the transformed ``(original, target)`` pair at *idx*."""
        original = self._to_rgb(self.originals[idx]['image'])
        target = self._to_rgb(self.targets[idx]['image'])
        return self.transform(original), self.transform(target)
# Training function
def train_model(epochs):
    """Train the pix2pix UNet and return it.

    Loads ``dataset_id`` from the Hugging Face Hub, wraps it in
    ``Pix2PixDataset``, and minimises an L1 loss with Adam (learning rate
    ``LR``) for ``epochs`` passes over the data.

    The model is resumed from ``model_repo_id`` when loading succeeds.
    Otherwise a freshly initialised UNet is used: the 1024px variant when
    ``big`` is true, the 256px one otherwise.

    Args:
        epochs: number of full passes over the dataset.

    Returns:
        The trained model, left on ``device``.
    """
    ds = load_dataset(dataset_id)
    # Image resizing / tensor conversion is applied by the dataset's
    # __getitem__, so no transform is built here. (A local transform defined
    # here was never visible to the dataset and has been removed.)
    dataset = Pix2PixDataset(ds)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize model, loss function, and optimizer.
    model_cls = big_UNet if big else small_UNet
    try:
        # NOTE(review): assumes the local UNet classes provide
        # ``from_pretrained`` (e.g. via huggingface_hub's
        # PyTorchModelHubMixin, which the ``push_to_hub`` call in
        # push_model_to_hub also relies on) — confirm. The former call to the
        # never-imported ``UNet2DModel`` always raised NameError, so training
        # silently restarted from scratch every time.
        model = model_cls.from_pretrained(model_repo_id)
    except Exception as err:
        print(f"Could not load pretrained weights from {model_repo_id} ({err!r}); "
              "starting from a freshly initialised model.")
        model = model_cls()
    model = model.to(device)
    # Hub loaders may hand back a model in eval mode; switch explicitly so
    # layers like dropout / batch-norm behave as in training.
    model.train()
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    # Training loop
    for epoch in range(epochs):
        for step, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()
            # NOTE(review): the model maps the *target* image back to the
            # *original* one, the reverse of the usual pix2pix direction.
            # This matches the original "cutout" comments, so it is kept —
            # confirm it is intended.
            output = model(target)
            loss = criterion(output, original)
            loss.backward()
            optimizer.step()
            if step % 100 == 0:
                # Report a 1-based epoch so the last one reads "[N/N]".
                print(f"Epoch [{epoch + 1}/{epochs}], Step [{step}/{len(dataloader)}], "
                      f"Loss: {loss.item():.4f}")
    return model
# Push model to Hugging Face Hub
def push_model_to_hub(model, repo_name):
    """Upload ``model`` to the Hugging Face Hub.

    Thin wrapper: forwards ``repo_name`` to the model's own ``push_to_hub``
    method and discards whatever that call returns.

    Args:
        model: any object exposing a ``push_to_hub(repo_name)`` method.
        repo_name: Hub repository id, e.g. ``"user/repo"``.
    """
    upload = getattr(model, "push_to_hub")
    upload(repo_name)
# Gradio interface function
def gradio_train(epochs):
    """Gradio callback: train the model, push it to the Hub, and report back.

    Args:
        epochs: value from the ``gr.Number`` input. Without a ``precision``
            setting it arrives as a float (or None when the field is empty)
            and is truncated to an int here.

    Returns:
        A human-readable status message.

    Raises:
        gr.Error: if *epochs* is missing or below 1, so that nothing is
            trained and no untrained model is pushed.
    """
    if epochs is None:
        raise gr.Error("Please enter the number of epochs.")
    n_epochs = int(epochs)
    if n_epochs < 1:
        raise gr.Error("Number of epochs must be at least 1.")
    model = train_model(n_epochs)
    push_model_to_hub(model, model_repo_id)
    # Report the integer actually used, not the raw float (e.g. "12.0").
    return f"Model trained for {n_epochs} epochs on the {dataset_id} dataset and pushed to Hugging Face Hub {model_repo_id} repository."
# Gradio Interface
# Single-page UI: one numeric input (epochs) mapped through gradio_train
# to a plain-text status output.
gr_interface = gr.Interface(
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository.",
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs"),
    outputs="text",
)
if __name__ == '__main__':
    # Clone the model repository into a local directory named after its id,
    # then pull the latest commits.
    # NOTE(review): nothing visible in this file reads this clone (the model
    # is uploaded via push_to_hub), and ``Repository`` is deprecated in
    # recent huggingface_hub releases — verify whether it is still needed.
    repo = Repository(clone_from=model_repo_id, local_dir=model_repo_id)
    repo.git_pull()
    # Serve the Gradio training UI.
    gr_interface.launch()