K00B404 committed
Commit 6e15e32 · verified · 1 Parent(s): 9adfffc

Update app.py

Files changed (1)
  1. app.py +28 -16
app.py CHANGED
@@ -13,7 +13,7 @@ from PIL import Image
 import numpy as np
 from small_256_model import UNet as small_UNet
 from big_1024_model import UNet as big_UNet
-from CLIP import load as load_clip,load_vae
+from CLIP import load as load_clip,load_vae,encode_prompt
 from rich import print as rp
 
 # Device configuration
@@ -56,7 +56,7 @@ def load_model():
     return model
 
 class Pix2PixDataset(torch.utils.data.Dataset):
-    def __init__(self, combined_data, transform, clip_tokenizer):
+    def __init__(self, combined_data, transform, clip_tokenizer, clip_model):
         self.data = combined_data
         self.transform = transform
         self.clip_tokenizer = clip_tokenizer
@@ -81,12 +81,24 @@ class Pix2PixDataset(torch.utils.data.Dataset):
         # Get prompts from the DataFrame
         original_prompt = self.data.iloc[idx]['original_prompt']
         enhanced_prompt = self.data.iloc[idx]['enhanced_prompt']
+
+        # Encode images
+        original_image_latents = vae.encode(original_images).latent_dist.sample()
+        target_image_latents = vae.encode(target_images).latent_dist.sample()
 
+        # Encode prompts
+        prompt_latents = encode_prompt(enhanced_prompt,clip_model,clip_tokenizer)
+
+        # Pass these to your Pix2Pix model
+        #generated_images = pix2pix_model(original_latents, prompt_latents)
+
+
+        return original_image_latents,target_image_latents,prompt_latents
         # Tokenize the prompts using CLIP tokenizer
-        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
-        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+        #original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+        #enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
 
-        return original, target, original_tokens, enhanced_tokens
+        #return original, target, original_tokens, enhanced_tokens
 
 
 
@@ -261,11 +273,11 @@ def train_model(epochs, save_interval=1):
     ])
 
     # Initialize dataset and dataloader
-    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
+    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer, clip_model)
     dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
 
     model = global_model
-    criterion = nn.L1Loss()
+    criterion = nn.L1Loss() # You may change this to suit your loss calculation needs
     optimizer = optim.Adam(model.parameters(), lr=LR)
     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # Example scheduler
     wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)
@@ -276,17 +288,17 @@ def train_model(epochs, save_interval=1):
         model.train()
         running_loss = 0.0
 
-        for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
+        for i, (latent_original, latent_target, latent_prompt) in enumerate(dataloader):
             # Move data to device
-            original, target = original.to(device), target.to(device)
-            original_prompt_tokens = original_prompt_tokens.input_ids.to(device).float()
-            enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device).float()
-
+            latent_original, latent_target, latent_prompt = latent_original.to(device), latent_target.to(device), latent_prompt.to(device)
+
             optimizer.zero_grad()
 
-            # Forward pass
-            output = model(target)
-            img_loss = criterion(output, original)
+            # Forward pass with the latents
+            output = model(latent_target, latent_prompt) # Assuming your model can take both target and prompt latents
+
+            # Calculate loss using the original latents
+            img_loss = criterion(output, latent_original)
             total_loss = img_loss
             total_loss.backward()
             optimizer.step()
@@ -304,7 +316,7 @@ def train_model(epochs, save_interval=1):
 
         # Save checkpoint at specified intervals
         if (epoch + 1) % save_interval == 0:
-            checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else f'small_checkpoint_epoch_{epoch+1}.pth'
+            checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else f'small_checkpoint_epoch_{epoch+1}.pth'
             wrapper.save_checkpoint(checkpoint_path)
             wrapper.push_to_hub(checkpoint_path)
 
 
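The updated import pulls an encode_prompt helper out of CLIP.py, which is not part of this diff. Below is a minimal sketch of what such a helper commonly looks like when built on the Hugging Face transformers CLIP text encoder; the checkpoint name, the device argument, and the exact signature are assumptions for illustration, not the repo's actual code.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

def encode_prompt(prompt, clip_model, clip_tokenizer, device="cpu"):
    # Tokenize to a fixed 77-token sequence (matching the commented-out
    # tokenizer calls in the dataset) and return the text encoder's hidden
    # states as the prompt latents.
    tokens = clip_tokenizer(prompt, return_tensors="pt", padding="max_length",
                            truncation=True, max_length=77).to(device)
    with torch.no_grad():
        return clip_model(**tokens).last_hidden_state  # (batch, 77, hidden_size)

# Example usage with an assumed checkpoint:
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
print(encode_prompt("a photo of a cat", text_model, tokenizer).shape)  # torch.Size([1, 77, 512])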
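__getitem__ now produces image latents with vae.encode(...).latent_dist.sample(). That call pattern matches the diffusers AutoencoderKL API; a self-contained sketch of the pattern follows. The checkpoint name and the dummy tensor are assumptions, and in the actual code the VAE would come from load_vae and the inputs would be the transformed original/target images.

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # assumed checkpoint
vae.eval()

# Dummy batch standing in for a transformed image: (batch, channels, H, W),
# values roughly in [-1, 1].
images = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    # encode() returns a posterior distribution over latents; sample() draws
    # a latent tensor that is 8x smaller spatially, with 4 channels.
    latents = vae.encode(images).latent_dist.sample()

print(latents.shape)  # torch.Size([1, 4, 32, 32])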
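The training loop now calls output = model(latent_target, latent_prompt), with a comment noting the assumption that the model accepts both inputs. The small_256_model and big_1024_model UNets are not shown in this diff; purely as an illustration of one common way a UNet-style forward pass can consume a prompt embedding alongside image latents (pool the token embeddings, project them, and add the result to the bottleneck features), here is a toy, hypothetical example.

import torch
import torch.nn as nn

class TinyConditionedUNet(nn.Module):
    # Hypothetical stand-in, not the repo's UNet.
    def __init__(self, in_channels=4, base=32, prompt_dim=512):
        super().__init__()
        self.down = nn.Sequential(
            nn.Conv2d(in_channels, base, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(base, base * 2, 3, stride=2, padding=1), nn.ReLU(),
        )
        # Project the pooled prompt embedding to the bottleneck channel count.
        self.prompt_proj = nn.Linear(prompt_dim, base * 2)
        self.up = nn.Sequential(
            nn.ConvTranspose2d(base * 2, base, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(base, in_channels, 4, stride=2, padding=1),
        )

    def forward(self, image_latents, prompt_latents):
        h = self.down(image_latents)                       # (B, base*2, H/4, W/4)
        cond = prompt_latents.mean(dim=1)                  # pool tokens: (B, prompt_dim)
        h = h + self.prompt_proj(cond)[:, :, None, None]   # broadcast over H and W
        return self.up(h)

model = TinyConditionedUNet()
out = model(torch.randn(2, 4, 32, 32), torch.randn(2, 77, 512))
print(out.shape)  # torch.Size([2, 4, 32, 32])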