K00B404 committed on
Commit 803ef66 · verified · 1 Parent(s): 1b3e0c4

Update app_bck.py

Files changed (1)
  1. app_bck.py +399 -0
app_bck.py CHANGED
@@ -0,0 +1,399 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+ from datasets import load_dataset
+ from huggingface_hub import HfApi, HfFolder, Repository, create_repo
+ import os
+ import pandas as pd
+ import gradio as gr
+ from PIL import Image
+ import numpy as np
+ from small_256_model import UNet as small_UNet
+ from big_1024_model import UNet as big_UNet
+ from CLIP import load as load_clip
+ from rich import print as rp
+
+ # Device configuration
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ big = device.type == 'cuda'  # use the big 1024px model only when a GPU is available
+
+ # Parameters
+ IMG_SIZE = 1024 if big else 256
+ BATCH_SIZE = 1  # single-image batches for both model sizes
+ EPOCHS = 12
+ LR = 0.0002
+ dataset_id = "K00B404/pix2pix_flux_set"
+ model_repo_id = "K00B404/pix2pix_flux"
+
+ # Global model variable
+ global_model = None
+
+ # CLIP
+ clip_model, clip_tokenizer = load_clip()
+
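+ # Note on the CLIP helper imported above: `load_clip` comes from the local `CLIP` module and is
+ # assumed here to return a (model, tokenizer) pair whose tokenizer follows the Hugging Face
+ # tokenizer call signature, since Pix2PixDataset below calls it with return_tensors="pt",
+ # padding=True, truncation=True and max_length=77.
+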
+ def load_model():
+     """Load the models at startup"""
+     global global_model
+     weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
+     try:
+         checkpoint = torch.load(weights_name, map_location=device)
+         model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
+         model.load_state_dict(checkpoint['model_state_dict'])
+         model.to(device)
+         model.eval()
+         global_model = model
+         rp("Model loaded successfully!")
+         return model
+     except Exception as e:
+         rp(f"Error loading model: {e}")
+         model = big_UNet().to(device) if big else small_UNet().to(device)
+         global_model = model
+         return model
+
+ class Pix2PixDataset(torch.utils.data.Dataset):
+     def __init__(self, combined_data, transform, clip_tokenizer):
+         self.data = combined_data
+         self.transform = transform
+         self.clip_tokenizer = clip_tokenizer
+         self.original_folder = 'images_dataset/original/'
+         self.target_folder = 'images_dataset/target/'
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         original_img_filename = os.path.basename(self.data.iloc[idx]['image_path'])
+         original_img_path = os.path.join(self.original_folder, original_img_filename)
+         target_img_path = os.path.join(self.target_folder, original_img_filename)
+
+         original_img = Image.open(original_img_path).convert('RGB')
+         target_img = Image.open(target_img_path).convert('RGB')
+
+         # Transform images
+         original = self.transform(original_img)
+         target = self.transform(target_img)
+
+         # Get prompts from the DataFrame
+         original_prompt = self.data.iloc[idx]['original_prompt']
+         enhanced_prompt = self.data.iloc[idx]['enhanced_prompt']
+
+         # Tokenize the prompts using the CLIP tokenizer
+         original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+         enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+
+         return original, target, original_tokens, enhanced_tokens
+
+
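+ # Expected on-disk layout for Pix2PixDataset (inferred from the class above; a sketch, not part of this commit):
+ #   combined_data.csv         with columns: image_path, original_prompt, enhanced_prompt
+ #   images_dataset/original/  source images, matched to targets by filename
+ #   images_dataset/target/    paired target images with the same filenames
+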
+ class UNetWrapper:
+     def __init__(self, unet_model, repo_id, epoch, loss, optimizer, scheduler=None):
+         self.loss = loss
+         self.epoch = epoch
+         self.model = unet_model
+         self.optimizer = optimizer
+         self.scheduler = scheduler
+         self.repo_id = repo_id
+         self.token = os.getenv('NEW_TOKEN')  # Ensure the token is set in the environment
+         self.api = HfApi(token=self.token)
+
+     def save_checkpoint(self, save_path):
+         """Save checkpoint with model, optimizer, and scheduler states."""
+         self.save_dict = {
+             'model_state_dict': self.model.state_dict(),
+             'optimizer_state_dict': self.optimizer.state_dict(),
+             'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
+             'model_config': {
+                 'big': isinstance(self.model, big_UNet),
+                 'img_size': 1024 if isinstance(self.model, big_UNet) else 256
+             },
+             'epoch': self.epoch,
+             'loss': self.loss
+         }
+         torch.save(self.save_dict, save_path)
+         print(f"Checkpoint saved at epoch {self.epoch}, loss: {self.loss}")
+
+     def load_checkpoint(self, checkpoint_path):
+         """Load model, optimizer, and scheduler states from the checkpoint."""
+         checkpoint = torch.load(checkpoint_path, map_location=device)
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+         if self.scheduler and checkpoint['scheduler_state_dict']:
+             self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+         self.epoch = checkpoint['epoch']
+         self.loss = checkpoint['loss']
+         print(f"Checkpoint loaded: epoch {self.epoch}, loss: {self.loss}")
+
+     def push_to_hub(self, pth_name):
+         """Push model checkpoint and metadata to the Hugging Face Hub."""
+         try:
+             self.api.upload_file(
+                 path_or_fileobj=pth_name,
+                 path_in_repo=pth_name,
+                 repo_id=self.repo_id,
+                 token=self.token,
+                 repo_type="model"
+             )
+             print(f"Model checkpoint successfully uploaded to {self.repo_id}")
+         except Exception as e:
+             print(f"Error uploading model: {e}")
+
+         # Create and upload model card
+         model_card = f"""---
+ tags:
+ - unet
+ - pix2pix
+ - pytorch
+ library_name: pytorch
+ license: wtfpl
+ datasets:
+ - K00B404/pix2pix_flux_set
+ language:
+ - en
+ pipeline_tag: image-to-image
+ ---
+ # Pix2Pix UNet Model
+ ## Model Description
+ Custom UNet model for Pix2Pix image translation.
+ - **Image Size:** {self.save_dict['model_config']['img_size']}
+ - **Model Type:** {"big" if big else "small"}_UNet ({self.save_dict['model_config']['img_size']})
+ ## Usage
+ ```python
+ import torch
+ from small_256_model import UNet as small_UNet
+ from big_1024_model import UNet as big_UNet
+ big = True
+ # Load the model
+ name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
+ checkpoint = torch.load(name)
+ model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
+ model.load_state_dict(checkpoint['model_state_dict'])
+ model.eval()
+ ```
+ ## Model Architecture
+ {str(self.model)}"""
+         rp(model_card)
+         try:
+             # Save and upload README
+             with open("README.md", "w") as f:
+                 f.write(f"# Pix2Pix UNet Model\n\n"
+                         f"- **Image Size:** {self.save_dict['model_config']['img_size']}\n"
+                         f"- **Model Type:** {'big' if big else 'small'}_UNet ({self.save_dict['model_config']['img_size']})\n"
+                         f"## Model Architecture\n{str(self.model)}")
+
+             self.api.upload_file(
+                 path_or_fileobj="README.md",
+                 path_in_repo="README.md",
+                 repo_id=self.repo_id,
+                 token=self.token,
+                 repo_type="model"
+             )
+
+             # Clean up local files
+             os.remove(pth_name)
+             os.remove("README.md")
+
+             print(f"Model successfully uploaded to {self.repo_id}")
+
+         except Exception as e:
+             print(f"Error uploading model: {e}")
+
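+ # Typical UNetWrapper flow (illustrative sketch; train_model below does exactly this):
+ #   wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)
+ #   wrapper.save_checkpoint('small_checkpoint_epoch_1.pth')  # also builds self.save_dict, which push_to_hub reads
+ #   wrapper.push_to_hub('small_checkpoint_epoch_1.pth')      # uploads the .pth plus a generated README.md
+ # The wrapper reads the NEW_TOKEN environment variable at construction time, so it must be set beforehand.
+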
+ def prepare_input(image, device='cpu'):
+     """Prepare image for inference"""
+     transform = transforms.Compose([
+         transforms.Resize((IMG_SIZE, IMG_SIZE)),
+         transforms.ToTensor(),
+     ])
+
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+     input_tensor = transform(image).unsqueeze(0).to(device)
+     return input_tensor
+
+ def run_inference(image):
+     """Run inference on a single image"""
+     global global_model
+     if global_model is None:
+         return "Error: Model not loaded"
+
+     global_model.eval()
+     input_tensor = prepare_input(image, device)
+
+     with torch.no_grad():
+         output = global_model(input_tensor)
+
+     # Convert output to image
+     output = output.cpu().squeeze(0).permute(1, 2, 0).numpy()
+     output = ((output - output.min()) / (output.max() - output.min()) * 255).astype(np.uint8)
+     rp(output[0])
+     return output
+
+ def to_hub(model, epoch, loss):
+     # Legacy helper used by train_model_old: build a wrapper with a fresh optimizer state,
+     # save a full checkpoint, then push that checkpoint file to the Hub.
+     optimizer = optim.Adam(model.parameters(), lr=LR)
+     wrapper = UNetWrapper(model, model_repo_id, epoch, loss, optimizer)
+     pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
+     wrapper.save_checkpoint(pth_name)
+     wrapper.push_to_hub(pth_name)
+
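+ # Inference note: run_inference min-max normalizes the raw network output to uint8 per image, so
+ # absolute intensities are not preserved. A minimal call sketch ('some_input.png' is a hypothetical path):
+ #   img = np.array(Image.open('some_input.png').convert('RGB'))
+ #   result = run_inference(img)  # (IMG_SIZE, IMG_SIZE, 3) uint8 array, or an error string if no model is loaded
+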
+ def train_model(epochs, save_interval=1):
+     """Training function with checkpoint saving and model uploading."""
+     global global_model
+
+     # Load combined data CSV
+     data_path = 'combined_data.csv'
+     combined_data = pd.read_csv(data_path)
+
+     # Define the transformation
+     transform = transforms.Compose([
+         transforms.Resize((IMG_SIZE, IMG_SIZE)),
+         transforms.ToTensor(),
+     ])
+
+     # Initialize dataset and dataloader
+     dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
+     dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+     model = global_model
+     criterion = nn.L1Loss()
+     optimizer = optim.Adam(model.parameters(), lr=LR)
+     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # Example scheduler
+     wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)
+
+     output_text = []
+
+     for epoch in range(epochs):
+         model.train()
+         running_loss = 0.0
+
+         for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
+             # Move data to device
+             original, target = original.to(device), target.to(device)
+             original_prompt_tokens = original_prompt_tokens.input_ids.to(device).float()
+             enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device).float()
+
+             optimizer.zero_grad()
+
+             # Forward pass
+             output = model(target)
+             img_loss = criterion(output, original)
+             total_loss = img_loss
+             total_loss.backward()
+             optimizer.step()
+
+             running_loss += total_loss.item()
+
+             if i % 10 == 0:
+                 status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
+                 print(status)
+                 output_text.append(status)
+
+         # Update the epoch and loss for checkpoint
+         wrapper.epoch = epoch + 1
+         wrapper.loss = running_loss / len(dataloader)
+
+         # Save checkpoint at specified intervals
+         if (epoch + 1) % save_interval == 0:
+             checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else f'small_checkpoint_epoch_{epoch+1}.pth'
+             wrapper.save_checkpoint(checkpoint_path)
+             wrapper.push_to_hub(checkpoint_path)
+
+         scheduler.step()  # Update learning rate scheduler
+
+     global_model = model  # Update global model after training
+     return model, "\n".join(output_text)
+
+
+ def train_model_old(epochs):
+     """Training function"""
+     global global_model
+
+     # Load combined data CSV
+     data_path = 'combined_data.csv'  # Adjust this path
+     combined_data = pd.read_csv(data_path)
+
+     # Define the transformation
+     transform = transforms.Compose([
+         transforms.Resize((IMG_SIZE, IMG_SIZE)),
+         transforms.ToTensor(),
+     ])
+
+     # Initialize the dataset and dataloader
+     dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
+     dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
+
+     model = global_model
+     criterion = nn.L1Loss()  # L1 loss for image reconstruction
+     optimizer = optim.Adam(model.parameters(), lr=LR)
+     output_text = []
+
+     for epoch in range(epochs):
+         model.train()
+         for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
+             # Move images and prompt embeddings to the appropriate device (CPU or GPU)
+             original, target = original.to(device), target.to(device)
+             original_prompt_tokens = original_prompt_tokens.input_ids.to(device).float()  # Convert to float
+             enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device).float()  # Convert to float
+
+             optimizer.zero_grad()
+
+             # Forward pass through the model
+             output = model(target)
+
+             # Compute image reconstruction loss
+             img_loss = criterion(output, original)
+             rp(f"Image {i} Loss:{img_loss}")
+
+             # Combine losses
+             total_loss = img_loss  # Add any other losses if necessary
+             total_loss.backward()
+
+             # Optimizer step
+             optimizer.step()
+
+             if i % 10 == 0:
+                 status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
+                 rp(status)
+                 output_text.append(status)
+
+         # Push model to Hugging Face Hub at the end of each epoch
+         to_hub(model, epoch, total_loss)
+
+     global_model = model  # Update the global model after training
+     return model, "\n".join(output_text)
+
+ def gradio_train(epochs):
+     # Gradio training interface function
+     model, training_log = train_model(int(epochs))
+     #to_hub(model)
+     return f"{training_log}\n\nModel trained for {epochs} epochs and pushed to {model_repo_id}"
+
+ def gradio_inference(input_image):
+     # Gradio inference interface function
+     output_image = run_inference(input_image)  # returns a uint8 array, or an error string if no model is loaded
+     rp(output_image)
+     return output_image
+
+
+ # Create Gradio interface with tabs
+ with gr.Blocks() as app:
+     gr.Markdown("# Pix2Pix Model Training and Inference")
+
+     with gr.Tab("Train"):
+         epochs_input = gr.Number(value=EPOCHS, label="Number of epochs")
+         train_button = gr.Button("Train")
+         training_output = gr.Textbox(label="Training Log", interactive=False)
+         train_button.click(gradio_train, inputs=[epochs_input], outputs=[training_output])
+
+     with gr.Tab("Inference"):
+         image_input = gr.Image(type='numpy')
+         prompt_input = gr.Textbox(label="Prompt")  # collected in the UI but not yet passed to inference
+         inference_button = gr.Button("Generate")
+         inference_output = gr.Image(type='numpy', label="Generated Image")
+         inference_button.click(gradio_inference, inputs=[image_input], outputs=[inference_output])
+
+ load_model()
+ app.launch()