"""Train a simplified Pix2Pix (UNet) model and push it to the Hugging Face Hub.

Exposes a small Gradio UI whose single input is the number of training
epochs; training runs on the ``K00B404/pix2pix_flux_set`` dataset and the
resulting weights are pushed to the ``K00B404/pix2pix_flux`` model repo.
"""

import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import Repository, create_repo
import gradio as gr
from PIL import Image

# Hyperparameters
IMG_SIZE = 256
BATCH_SIZE = 1
EPOCHS = 12
LR = 0.0002

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class UNet(nn.Module):
    """Simplified encoder/decoder "UNet" for Pix2Pix.

    Note: this is a plain autoencoder shape — it has no skip connections,
    so it is "UNet" in name only.  Input and output are 3-channel images of
    IMG_SIZE x IMG_SIZE; the final Tanh constrains outputs to [-1, 1].
    """

    def __init__(self):
        super(UNet, self).__init__()
        # Encoder: 256 -> 8 spatial, 3 -> 1024 channels.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),      # 256 -> 128
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),    # 128 -> 64
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),   # 64 -> 32
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),   # 32 -> 16
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1),  # 16 -> 8
            nn.ReLU(inplace=True),
        )
        # Decoder: mirror of the encoder, back to a 3-channel image.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),  # 8 -> 16
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),   # 16 -> 32
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),   # 32 -> 64
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),    # 64 -> 128
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),      # 128 -> 256
            nn.Tanh(),  # output range [-1, 1]
        )

    def forward(self, x):
        """Encode then decode ``x``; returns a tensor the same shape as ``x``."""
        enc = self.encoder(x)
        dec = self.decoder(enc)
        return dec


def train_model(epochs):
    """Train a fresh UNet for ``epochs`` epochs and return the trained model.

    Loads the ``K00B404/pix2pix_flux_set`` dataset, trains with L1 loss and
    Adam on ``device``, printing the loss every 100 steps.
    """
    ds = load_dataset("K00B404/pix2pix_flux_set")

    # Resize, convert to tensor, and scale to [-1, 1].  Normalization is
    # required because the generator ends in Tanh ([-1, 1]) while ToTensor
    # alone yields [0, 1] — without it the L1 target range never matches
    # the model's output range.
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    class Pix2PixDataset(Dataset):
        """(original, target) image-tensor pairs from the dataset's train split."""

        def __init__(self, ds):
            self.ds = ds

        def __len__(self):
            return len(self.ds["train"])

        def __getitem__(self, idx):
            row = self.ds["train"][idx]
            # NOTE(review): assumes the columns hold file paths / file-like
            # objects; if the dataset yields already-decoded PIL images,
            # Image.open here will fail — confirm against the dataset schema.
            original = Image.open(row['original_image']).convert('RGB')
            target = Image.open(row['target_image']).convert('RGB')
            return transform(original), transform(target)

    dataloader = DataLoader(Pix2PixDataset(ds), batch_size=BATCH_SIZE, shuffle=True)

    # Model, loss, and optimizer.
    model = UNet().to(device)
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    # Training loop.
    model.train()
    for epoch in range(epochs):
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)

            optimizer.zero_grad()

            # NOTE(review): this learns target -> original, the reverse of
            # the usual pix2pix direction (original -> target).  Preserved
            # as written — confirm which mapping is intended.
            output = model(target)
            loss = criterion(output, original)

            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}")

    return model


def push_model_to_hub(model, repo_name):
    """Save ``model``'s state dict into a local clone of ``repo_name`` and push it.

    Clones (or reuses) the repo directory *before* writing the weights, then
    pushes exactly once — the previous version pushed an empty repo first and
    then pushed again after saving.
    """
    # NOTE(review): huggingface_hub.Repository is deprecated in recent
    # versions; HfApi.upload_file is the modern replacement.
    repo = Repository(local_dir=repo_name, clone_from=repo_name)

    # Save the model state dict inside the repo working tree.
    model_save_path = os.path.join(repo_name, "pix2pix_model.pth")
    torch.save(model.state_dict(), model_save_path)

    # Commit and push the saved weights.
    repo.push_to_hub(commit_message="Initial commit with trained Pix2Pix model.")


def gradio_train(epochs):
    """Gradio callback: train for ``epochs`` epochs, push, and report status."""
    model = train_model(int(epochs))
    push_model_to_hub(model, "K00B404/pix2pix_flux")
    return f"Model trained for {epochs} epochs and pushed to Hugging Face Hub repository 'K00B404/pix2pix_flux'."


# Gradio Interface
gr_interface = gr.Interface(
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs"),
    outputs="text",
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository.",
)

if __name__ == '__main__':
    # Create the target repository if it does not already exist.
    create_repo("K00B404/pix2pix_flux", exist_ok=True)

    # Launch the Gradio app.
    gr_interface.launch()