Update app.py
app.py CHANGED
@@ -13,7 +13,7 @@ from PIL import Image
 import numpy as np
 from small_256_model import UNet as small_UNet
 from big_1024_model import UNet as big_UNet
-from CLIP import load as load_clip,load_vae
+from CLIP import load as load_clip,load_vae,encode_prompt
 from rich import print as rp
 
 # Device configuration
@@ -56,7 +56,7 @@ def load_model():
     return model
 
 class Pix2PixDataset(torch.utils.data.Dataset):
-    def __init__(self, combined_data, transform, clip_tokenizer):
+    def __init__(self, combined_data, transform, clip_tokenizer, clip_model):
         self.data = combined_data
         self.transform = transform
         self.clip_tokenizer = clip_tokenizer
@@ -81,12 +81,24 @@ class Pix2PixDataset(torch.utils.data.Dataset):
         # Get prompts from the DataFrame
         original_prompt = self.data.iloc[idx]['original_prompt']
         enhanced_prompt = self.data.iloc[idx]['enhanced_prompt']
+
+        # Encode images
+        original_image_latents = vae.encode(original_images).latent_dist.sample()
+        target_image_latents = vae.encode(target_images).latent_dist.sample()
 
+        # Encode prompts
+        prompt_latents = encode_prompt(enhanced_prompt,clip_model,clip_tokenizer)
+
+        # Pass these to your Pix2Pix model
+        #generated_images = pix2pix_model(original_latents, prompt_latents)
+
+
+        return original_image_latents,target_image_latents,prompt_latents
         # Tokenize the prompts using CLIP tokenizer
-        original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
-        enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+        #original_tokens = self.clip_tokenizer(original_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
+        #enhanced_tokens = self.clip_tokenizer(enhanced_prompt, return_tensors="pt", padding=True, truncation=True, max_length=77)
 
-        return original, target, original_tokens, enhanced_tokens
+        #return original, target, original_tokens, enhanced_tokens
 
 
 
@@ -261,11 +273,11 @@ def train_model(epochs, save_interval=1):
     ])
 
     # Initialize dataset and dataloader
-    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer)
+    dataset = Pix2PixDataset(combined_data, transform, clip_tokenizer, clip_model)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
 
    model = global_model
-    criterion = nn.L1Loss()
+    criterion = nn.L1Loss()  # You may change this to suit your loss calculation needs
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # Example scheduler
    wrapper = UNetWrapper(model, model_repo_id, epoch=0, loss=0.0, optimizer=optimizer, scheduler=scheduler)
@@ -276,17 +288,17 @@ def train_model(epochs, save_interval=1):
        model.train()
        running_loss = 0.0
 
-        for i, (
+        for i, (latent_original, latent_target, latent_prompt) in enumerate(dataloader):
            # Move data to device
-
-
-            enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device).float()
-
+            latent_original, latent_target, latent_prompt = latent_original.to(device), latent_target.to(device), latent_prompt.to(device)
+
            optimizer.zero_grad()
 
-            # Forward pass
-            output = model(target
-
+            # Forward pass with the latents
+            output = model(latent_target, latent_prompt)  # Assuming your model can take both target and prompt latents
+
+            # Calculate loss using the original latents
+            img_loss = criterion(output, latent_original)
            total_loss = img_loss
            total_loss.backward()
            optimizer.step()
@@ -304,7 +316,7 @@ def train_model(epochs, save_interval=1):
 
        # Save checkpoint at specified intervals
        if (epoch + 1) % save_interval == 0:
-            checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else
+            checkpoint_path = f'big_checkpoint_epoch_{epoch+1}.pth' if big else f'small_checkpoint_epoch_{epoch+1}.pth'
            wrapper.save_checkpoint(checkpoint_path)
            wrapper.push_to_hub(checkpoint_path)
 
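Note on the new dataset code: the added `__getitem__` lines reference `vae`, `original_images` and `target_images`, which are not defined in the lines shown, and the new `clip_model` argument is not stored in the visible part of `__init__`, so some extra wiring is implied elsewhere in app.py. Below is a minimal sketch (not the repository's actual code) of how the dataset could hold the VAE and CLIP model and return the latent triple the updated training loop consumes. The `vae` constructor argument and the `load_image_pair` helper are assumptions for illustration; `encode_prompt` is taken to be the helper newly imported from the local CLIP module, and the VAE is assumed to expose a diffusers-style `encode(...).latent_dist.sample()` interface as used in the diff.

import torch
from CLIP import encode_prompt  # helper added by this commit's import change

class Pix2PixDataset(torch.utils.data.Dataset):
    # Sketch: mirrors the new __init__ signature, plus an explicit vae argument (assumption).
    def __init__(self, combined_data, transform, clip_tokenizer, clip_model, vae):
        self.data = combined_data
        self.transform = transform
        self.clip_tokenizer = clip_tokenizer
        self.clip_model = clip_model
        self.vae = vae

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Hypothetical helper: load and transform the image pair the same way the
        # earlier, unchanged part of __getitem__ in app.py does.
        original, target = self.load_image_pair(idx)

        enhanced_prompt = self.data.iloc[idx]['enhanced_prompt']

        with torch.no_grad():
            # Add a batch dim for the VAE, sample a latent, then drop the batch dim
            # so the DataLoader can stack per-item latents into a batch.
            original_latents = self.vae.encode(original.unsqueeze(0)).latent_dist.sample().squeeze(0)
            target_latents = self.vae.encode(target.unsqueeze(0)).latent_dist.sample().squeeze(0)
            # Encode the enhanced prompt with the CLIP model and tokenizer.
            prompt_latents = encode_prompt(enhanced_prompt, self.clip_model, self.clip_tokenizer)

        return original_latents, target_latents, prompt_latents

If `encode_prompt` returns a batched tensor (for example shape [1, 77, hidden]), it would likely need a `squeeze(0)` before being returned so that the DataLoader's default collation stacks items into [batch, 77, hidden]; otherwise the training loop would receive an extra singleton dimension on `latent_prompt`.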