import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import Repository
import gradio as gr
from PIL import Image
import os

from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Device configuration: use the big 1024px model only when a GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
big = device.type == 'cuda'

# Parameters
IMG_SIZE = 1024 if big else 256
BATCH_SIZE = 16 if big else 1
EPOCHS = 12
LR = 0.0002

dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"


# Training function
def train_model(epochs):
    # Load the dataset
    ds = load_dataset(dataset_id)

    # Transform: resize to the model's input size and convert to tensor
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    # Create dataset and dataloader
    class Pix2PixDataset(torch.utils.data.Dataset):
        def __init__(self, ds):
            self.originals = [x for x in ds["train"] if x['label'] == 'original']
            self.targets = [x for x in ds["train"] if x['label'] == 'target']
            # Original and target images are assumed to correspond by index
            assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."

        def __len__(self):
            return len(self.originals)

        def __getitem__(self, idx):
            # Load the original and target images for the given index
            original_img = self.originals[idx]['image']
            target_img = self.targets[idx]['image']

            # The 'image' column may already hold a decoded PIL image;
            # only call Image.open() when it stores a file path instead
            if not isinstance(original_img, Image.Image):
                original_img = Image.open(original_img)
            if not isinstance(target_img, Image.Image):
                target_img = Image.open(target_img)

            original = original_img.convert('RGB')
            target = target_img.convert('RGB')

            # Return the transformed original and target images
            return transform(original), transform(target)

    dataset = Pix2PixDataset(ds)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize model, loss function, and optimizer.
    # Try to resume from the hub; fall back to a freshly initialized local UNet.
    try:
        from diffusers import UNet2DModel  # referenced in the original code but never imported
        model = UNet2DModel.from_pretrained(model_repo_id).to(device)
    except Exception:
        model = big_UNet().to(device) if big else small_UNet().to(device)
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    # Training loop
    for epoch in range(epochs):
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)

            optimizer.zero_grad()

            # Forward pass: generate an image from the target (cutout) input
            output = model(target)
            # Compare the generated image with the original image
            loss = criterion(output, original)

            # Backward pass
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], "
                      f"Loss: {loss.item():.4f}")

    # Return the trained model
    return model


# Push model to the Hugging Face Hub
def push_model_to_hub(model, repo_name):
    # Relies on the model class providing push_to_hub
    # (e.g. via huggingface_hub's PyTorchModelHubMixin)
    model.push_to_hub(repo_name)


# Gradio interface function
def gradio_train(epochs):
    model = train_model(int(epochs))
    push_model_to_hub(model, model_repo_id)
    return (f"Model trained for {epochs} epochs on the {dataset_id} dataset "
            f"and pushed to the Hugging Face Hub repository {model_repo_id}.")


# Gradio interface
gr_interface = gr.Interface(
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs"),
    outputs="text",
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
)

if __name__ == '__main__':
    # Create or clone the model repository if necessary
    repo = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
    repo.git_pull()

    # Launch the Gradio app
    gr_interface.launch()