Spaces:
Running
Running
File size: 3,951 Bytes
d89262b 3010c48 d89262b a6be944 d89262b 3010c48 d89262b 3010c48 d89262b d626bab d89262b d626bab d89262b d626bab d89262b 3010c48 d89262b 3010c48 d89262b 3010c48 d89262b 3010c48 d89262b 3010c48 d89262b 3010c48 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import Repository
import gradio as gr
from PIL import Image
import os
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet
# --- Runtime and hyper-parameter configuration ------------------------------
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# The large 1024 px UNet is selected whenever we are not on the CPU.
big = device.type != 'cpu'
# Working resolution and batch size follow the chosen model size.
IMG_SIZE, BATCH_SIZE = (1024, 16) if big else (256, 1)
EPOCHS = 12   # epoch-count constant
LR = 2e-4     # Adam learning rate
# Hub identifiers: the training dataset and the model repository.
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"
# Training function


def _to_rgb(image):
    """Return ``image`` as an RGB PIL image.

    Accepts an already-decoded ``PIL.Image.Image`` (presumably what the
    dataset's ``image`` column yields; confirm) or anything ``Image.open``
    accepts, such as a file path.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    # The context manager closes the file handle; convert() returns a loaded copy.
    with Image.open(image) as img:
        return img.convert("RGB")


class Pix2PixDataset(torch.utils.data.Dataset):
    """Pair the 'original' and 'target' rows of a dataset split by position.

    Only row indices are collected up front. Images are fetched and decoded
    lazily in ``__getitem__`` rather than materialising every row twice.

    Args:
        split: a split object exposing a ``label`` column, ``features`` and
            row access by integer index (rows carry an ``image`` field).
        transform: callable applied to each RGB image before returning it.

    Raises:
        ValueError: if the numbers of 'original' and 'target' rows differ.
    """

    def __init__(self, split, transform):
        labels = list(split["label"])
        # A class-label column stores ints plus a ``names`` list; map the ints
        # back to names so the 'original'/'target' comparison works either way.
        label_names = getattr(split.features.get("label"), "names", None)
        if label_names is not None:
            labels = [label_names[label] for label in labels]
        self.split = split
        self.transform = transform
        # NOTE(review): pairing is purely by order within each label group;
        # this assumes originals and targets are listed in matching order.
        self.original_rows = [i for i, label in enumerate(labels) if label == "original"]
        self.target_rows = [i for i, label in enumerate(labels) if label == "target"]
        if len(self.original_rows) != len(self.target_rows):
            raise ValueError("Mismatch in number of original and target images.")

    def __len__(self):
        return len(self.original_rows)

    def __getitem__(self, idx):
        original = _to_rgb(self.split[self.original_rows[idx]]["image"])
        target = _to_rgb(self.split[self.target_rows[idx]]["image"])
        return self.transform(original), self.transform(target)


def _build_model():
    """Return the UNet for the current device, resuming from the Hub if possible.

    When the selected UNet class exposes ``from_pretrained``, it is tried with
    ``model_repo_id``. Any exception falls back to a freshly initialised model.
    """
    unet_cls = big_UNet if big else small_UNet
    # NOTE(review): whether the local UNet classes provide from_pretrained
    # (e.g. through a Hub mixin) is not visible in this file — confirm.
    from_pretrained = getattr(unet_cls, "from_pretrained", None)
    if from_pretrained is not None:
        try:
            return from_pretrained(model_repo_id).to(device)
        except Exception as err:  # best-effort resume: any failure -> train from scratch
            print(f"Could not load {model_repo_id!r} from the Hub ({err!r}); training from scratch.")
    return unet_cls().to(device)


def train_model(epochs):
    """Train the pix2pix UNet on the ``dataset_id`` train split and return it.

    For every pair the model receives the 'target' image and is optimised
    with L1 loss towards the matching 'original' image.

    Args:
        epochs: number of full passes over the dataset.

    Returns:
        The trained model, on ``device``.

    Raises:
        ValueError: if the split contains no original/target pairs, or if
            their counts differ.
    """
    ds = load_dataset(dataset_id)
    # Resize to the model's working resolution, then convert to a tensor.
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    dataset = Pix2PixDataset(ds["train"], transform)
    if len(dataset) == 0:
        # Without this guard the loop below is a no-op and an untrained model
        # would be returned to the caller.
        raise ValueError(f"No 'original'/'target' image pairs found in {dataset_id!r}.")
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = _build_model()
    model.train()  # a loaded checkpoint may come back in eval mode
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    steps_per_epoch = len(dataloader)
    for epoch in range(epochs):
        for step, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()
            # NOTE(review): the model maps target -> original; confirm this is
            # the intended translation direction.
            output = model(target)
            loss = criterion(output, original)
            loss.backward()
            optimizer.step()
            if step % 100 == 0:
                print(f"Epoch [{epoch + 1}/{epochs}], Step [{step}/{steps_per_epoch}], Loss: {loss.item():.4f}")
    return model
# Push model to Hugging Face Hub
def push_model_to_hub(model, repo_name):
    """Upload ``model`` to the Hugging Face Hub repository ``repo_name``.

    Delegates entirely to the model's own ``push_to_hub`` method, so the
    model object must provide one (presumably via a Hub mixin on the UNet
    classes — not visible in this file). Returns ``None``.
    """
    upload = model.push_to_hub
    upload(repo_name)
# Gradio interface function
def gradio_train(epochs):
    """Gradio callback: train the model, push it to the Hub and report back.

    Args:
        epochs: value from the epoch-count input. It may be a float, and is
            presumably ``None`` when the field is left empty; ``None`` falls
            back to ``EPOCHS``. Non-integers are truncated with ``int()``.

    Returns:
        A human-readable status message.

    Raises:
        ValueError: if fewer than one epoch is requested. This refuses to
            push an untrained model over the Hub repository.
    """
    num_epochs = EPOCHS if epochs is None else int(epochs)
    if num_epochs < 1:
        raise ValueError(f"Number of epochs must be at least 1, got {epochs!r}.")
    model = train_model(num_epochs)
    push_model_to_hub(model, model_repo_id)
    return (
        f"Model trained for {num_epochs} epochs on the {dataset_id} dataset "
        f"and pushed to Hugging Face Hub {model_repo_id} repository."
    )
# Gradio Interface: one numeric input (epoch count) wired to gradio_train,
# whose returned status string is shown as plain text.
gr_interface = gr.Interface(
    fn=gradio_train,
    # Pre-fill with EPOCHS so an untouched field does not submit an empty
    # value; precision=0 asks Gradio to deliver a whole number.
    inputs=gr.Number(label="Number of Epochs", value=EPOCHS, precision=0),
    outputs="text",
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository.",
)
if __name__ == '__main__':
    # Clone (or reuse) a local checkout of the model repo, then sync it.
    # NOTE(review): nothing visible in this file reads this checkout —
    # push_model_to_hub uploads via model.push_to_hub. Recent huggingface_hub
    # releases also deprecate Repository in favour of HfApi; verify against
    # the installed version and consider removing this step.
    repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
    repo.git_pull()
    # Launch the Gradio app.
    gr_interface.launch()