File size: 3,951 Bytes
d89262b
 
 
 
 
 
3010c48
d89262b
 
 
 
a6be944
 
d89262b
 
 
 
3010c48
 
 
 
 
 
 
 
 
d89262b
 
 
 
3010c48
d89262b
 
 
 
 
 
 
 
d626bab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d89262b
d626bab
 
 
d89262b
d626bab
 
d89262b
 
 
 
 
3010c48
 
 
 
 
d89262b
 
 
 
 
 
 
 
 
 
3010c48
 
d89262b
 
 
 
 
 
 
 
 
 
 
 
 
3010c48
 
d89262b
 
 
 
3010c48
 
d89262b
 
 
 
 
 
 
 
 
 
 
3010c48
 
 
 
d89262b
3010c48
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import Repository
import gradio as gr
from PIL import Image
import os

from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The large 1024px model (and its bigger batches) is only used off-CPU.
big = device.type != 'cpu'

# Training hyper-parameters and Hub identifiers.
IMG_SIZE, BATCH_SIZE = (1024, 16) if big else (256, 1)
EPOCHS = 12
LR = 0.0002
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"

# Training helpers and entry point


def _load_rgb(image):
    """Return *image* as an RGB PIL image.

    Accepts either an already-decoded ``PIL.Image.Image`` or anything that
    ``Image.open`` accepts (a path or a file object).
    """
    # NOTE(review): if the dataset's 'image' column is a `datasets` Image
    # feature it presumably yields decoded PIL images, which `Image.open`
    # would reject -- so only open values that are not images yet.
    if isinstance(image, Image.Image):
        return image.convert('RGB')
    # `with` closes the underlying file; `convert` returns an independent copy.
    with Image.open(image) as opened:
        return opened.convert('RGB')


class Pix2PixDataset(torch.utils.data.Dataset):
    """Paired (original, target) images taken from the ``"train"`` split of *ds*.

    Rows whose ``'label'`` is ``'original'`` are paired, in order of
    appearance, with rows whose ``'label'`` is ``'target'``: the i-th original
    is matched with the i-th target. Rows with any other label are ignored.

    Args:
        ds: Mapping with a ``"train"`` split whose rows are dicts holding
            ``'label'`` and ``'image'`` keys.
        transform: Optional callable applied to each RGB image. When ``None``,
            the PIL images are returned unchanged.

    Raises:
        ValueError: If the original/target counts differ, or no pairs exist.
    """

    def __init__(self, ds, transform=None):
        self.transform = transform
        self.originals = []
        self.targets = []
        # Single pass: iterating the split may decode every image, so scanning
        # it once per label would double that work.
        for row in ds["train"]:
            label = row['label']
            if label == 'original':
                self.originals.append(row)
            elif label == 'target':
                self.targets.append(row)

        # Explicit raises instead of `assert`, which `python -O` strips.
        if len(self.originals) != len(self.targets):
            raise ValueError("Mismatch in number of original and target images.")
        if not self.originals:
            # NOTE(review): if 'label' is a ClassLabel column its values are
            # ints, not 'original'/'target' strings -- confirm the schema.
            raise ValueError(
                "No rows labelled 'original'/'target' found in the train split."
            )

    def __len__(self):
        return len(self.originals)

    def __getitem__(self, idx):
        """Return the (original, target) pair at *idx*, transformed if configured."""
        original = _load_rgb(self.originals[idx]['image'])
        target = _load_rgb(self.targets[idx]['image'])
        if self.transform is not None:
            original, target = self.transform(original), self.transform(target)
        return original, target


def _init_model():
    """Return the UNet for the current device, resumed from ``model_repo_id`` when possible."""
    model_cls = big_UNet if big else small_UNet
    model = None
    # Resume from the weights last pushed to the Hub when the model class
    # offers `from_pretrained` (presumably via a huggingface_hub mixin, since
    # `push_model_to_hub` relies on `push_to_hub` -- confirm in the UNet modules).
    if hasattr(model_cls, 'from_pretrained'):
        try:
            model = model_cls.from_pretrained(model_repo_id)
        except Exception as err:  # e.g. missing repo or incompatible weights
            print(f"Could not load {model_repo_id} ({err!r}); training from scratch.")
    if model is None:
        model = model_cls()
    return model.to(device)


def train_model(epochs):
    """Train a UNet on the paired images of ``dataset_id`` and return it.

    Each step feeds a batch of 'target' images through the model and
    minimises the L1 distance between its output and the matching
    'original' images.

    Args:
        epochs: Number of full passes over the dataset.

    Returns:
        The trained model, on ``device`` and in training mode.

    Raises:
        ValueError: From ``Pix2PixDataset`` if the labelled rows are unbalanced
            or missing.
    """
    ds = load_dataset(dataset_id)

    # Resize to the model's square input size and convert to a tensor.
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])

    dataset = Pix2PixDataset(ds, transform=transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    n_steps = len(dataloader)

    model = _init_model()
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    model.train()
    for epoch in range(epochs):
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()

            # Predict the original image from the target image.
            output = model(target)
            loss = criterion(output, original)

            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch + 1}/{epochs}], Step [{i}/{n_steps}], Loss: {loss.item():.4f}")

    return model

def push_model_to_hub(model, repo_name):
    """Upload *model* to the Hugging Face Hub repository *repo_name*.

    Delegates entirely to ``model.push_to_hub``. The model object must
    provide that method (presumably via a huggingface_hub mixin -- confirm in
    the UNet modules). Always returns ``None``.
    """
    upload = model.push_to_hub
    upload(repo_name)

# Gradio interface function
def gradio_train(epochs):
    """Gradio callback: train for *epochs* epochs, push the model, report status.

    Args:
        epochs: Value from the ``gr.Number`` input. It may be ``None`` when
            the field is empty, or a float; it is truncated to an ``int``.

    Returns:
        A human-readable status message.

    Raises:
        ValueError: If *epochs* is missing or below 1. Training for zero
            epochs would still push the (untrained) model over the Hub repo.
    """
    if epochs is None:
        raise ValueError("Number of epochs is required.")
    n_epochs = int(epochs)
    if n_epochs < 1:
        raise ValueError(f"Number of epochs must be at least 1, got {epochs}.")

    model = train_model(n_epochs)
    push_model_to_hub(model, model_repo_id)
    return f"Model trained for {n_epochs} epochs on the {dataset_id} dataset and pushed to Hugging Face Hub {model_repo_id} repository."

# Gradio Interface: one numeric input (epochs) -> status text.
# `value=EPOCHS` pre-fills the configured default, and `precision=0` makes
# the component round to and return an int rather than a float.
gr_interface = gr.Interface(
    fn=gradio_train,
    inputs=gr.Number(label="Number of Epochs", value=EPOCHS, precision=0),
    outputs="text",
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository."
)

if __name__ == '__main__':
    # Set up a local git checkout of the model repo (at the relative path
    # `model_repo_id`) and pull its latest commits before serving the UI.
    # NOTE(review): `huggingface_hub.Repository` is deprecated in recent
    # releases, and push_model_to_hub() calls model.push_to_hub rather than
    # using this checkout -- confirm the clone is still needed.
    local_checkout = Repository(local_dir=model_repo_id, clone_from=model_repo_id)
    local_checkout.git_pull()

    # Start the Gradio server for the training interface.
    gr_interface.launch()