Vijish committed on
Commit
105147f
·
verified ·
1 Parent(s): 859bed9

Update utils/i2i.py

Browse files
Files changed (1) hide show
  1. utils/i2i.py +11 -54
utils/i2i.py CHANGED
@@ -36,7 +36,7 @@ trans_vae.load_state_dict(torch.load("./models/TransparentVAE.pth"), strict=Fals
36
  trans_vae.to(device)
37
 
38
  # Custom function to safely decode latents
39
- def safe_decode(trans_vae, latents, original_alpha=None):
40
  try:
41
  # Standard decoding approach
42
  original_x, x = trans_vae.decode(latents)
@@ -73,12 +73,6 @@ def i2i_gen(
73
  # Process the input image
74
  original_image = (transforms.ToTensor()(input_image)).unsqueeze(0)
75
 
76
- # Store the original alpha channel for later use
77
- original_alpha = None
78
- if original_image.shape[1] == 4:
79
- original_alpha = original_image[:, 3:4, :, :].clone()
80
- print("Original alpha channel detected and saved")
81
-
82
  # Print original image shape for debugging
83
  print(f"Original image shape: {original_image.shape}")
84
 
@@ -114,15 +108,11 @@ def i2i_gen(
114
  if height != original_image.shape[2] or width != original_image.shape[3]:
115
  print(f"Resizing image from {original_image.shape[2]}x{original_image.shape[3]} to {height}x{width}")
116
  original_image = transforms.functional.resize(original_image, (height, width))
117
-
118
- # Also resize the alpha channel if it exists
119
- if original_alpha is not None:
120
- original_alpha = transforms.functional.resize(original_alpha, (height, width))
121
 
122
  # Print resized image shape for debugging
123
  print(f"Resized image shape: {original_image.shape}")
124
 
125
- # Prepare the image for processing
126
  padding_feed = [x for x in original_image.movedim(1, -1).float().cpu().numpy()]
127
  list_of_np_rgb_padded = [pad_rgb(x) for x in padding_feed]
128
  rgb_padded_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgb_padded, axis=0)).float().movedim(-1, 1).to(device)
@@ -139,22 +129,18 @@ def i2i_gen(
139
  alpha = torch.ones((original_image_feed.shape[0], 1, height, width), device=original_image_feed.device)
140
  original_image_feed = torch.cat([original_image_feed, alpha], dim=1)
141
 
142
- # Print alpha channel shape for debugging
143
- print(f"Alpha channel shape: {original_image_feed[:, 3:4, :, :].shape}")
144
-
145
- # Make sure alpha channel is correctly extracted
146
- alpha_channel = original_image_feed[:, 3:4, :, :].clone()
147
-
148
- # Apply alpha to RGB channels
149
- original_image_rgb = original_image_feed[:, :3, :, :].clone() * alpha_channel
150
 
151
  # Print shape information for debugging
152
  print(f"RGB tensor shape: {original_image_feed[:, :3, :, :].shape}")
 
153
  print(f"RGB*alpha tensor shape: {original_image_rgb.shape}")
154
 
155
  # Move tensors to device
156
  original_image_feed = original_image_feed.to(device)
157
  original_image_rgb = original_image_rgb.to(device)
 
158
 
159
  # Verify tensor shapes before encoding
160
  print(f"Before encoding - original_image_feed: {original_image_feed.shape}")
@@ -194,7 +180,7 @@ def i2i_gen(
194
  raise
195
 
196
  # Free up memory
197
- del initial_latent
198
  torch.cuda.empty_cache()
199
 
200
  # Process the latents
@@ -233,42 +219,13 @@ def i2i_gen(
233
  del latents
234
  torch.cuda.empty_cache()
235
 
236
- # Convert to image with transparency
237
  x = x.clamp(0, 1)
238
-
239
- # Create a new tensor with 4 channels (RGBA)
240
- rgba_tensor = torch.zeros((x.shape[0], 4, x.shape[2], x.shape[3]), device=x.device)
241
-
242
- # Copy the RGB channels
243
- rgba_tensor[:, :3, :, :] = x
244
-
245
- # Use the original alpha channel if available, otherwise use the alpha from the model
246
- if original_alpha is not None:
247
- print("Using original alpha channel for output")
248
- # Resize alpha to match output dimensions if needed
249
- if original_alpha.shape[2:] != x.shape[2:]:
250
- original_alpha = torch.nn.functional.interpolate(
251
- original_alpha,
252
- size=(x.shape[2], x.shape[3]),
253
- mode='bilinear',
254
- align_corners=False
255
- )
256
- rgba_tensor[:, 3:4, :, :] = original_alpha.to(x.device)
257
- else:
258
- # If no original alpha, create a binary mask based on pixel intensity
259
- # Pixels that are nearly black (all channels < 0.05) will be transparent
260
- print("Creating alpha mask based on pixel intensity")
261
- rgb_sum = x.sum(dim=1, keepdim=True) / 3.0
262
- mask = (rgb_sum > 0.05).float()
263
- rgba_tensor[:, 3:4, :, :] = mask
264
-
265
- # Convert to PIL image with transparency
266
- rgba_tensor = rgba_tensor.permute(0, 2, 3, 1)
267
- rgba_array = (rgba_tensor[0] * 255).float().cpu().numpy().astype(np.uint8)
268
- img = Image.fromarray(rgba_array, mode='RGBA')
269
 
270
  # Clean up
271
- del original_x, x, rgba_tensor, original_image
272
  torch.cuda.empty_cache()
273
  gc.collect()
274
 
 
36
  trans_vae.to(device)
37
 
38
  # Custom function to safely decode latents
39
+ def safe_decode(trans_vae, latents):
40
  try:
41
  # Standard decoding approach
42
  original_x, x = trans_vae.decode(latents)
 
73
  # Process the input image
74
  original_image = (transforms.ToTensor()(input_image)).unsqueeze(0)
75
 
 
 
 
 
 
 
76
  # Print original image shape for debugging
77
  print(f"Original image shape: {original_image.shape}")
78
 
 
108
  if height != original_image.shape[2] or width != original_image.shape[3]:
109
  print(f"Resizing image from {original_image.shape[2]}x{original_image.shape[3]} to {height}x{width}")
110
  original_image = transforms.functional.resize(original_image, (height, width))
 
 
 
 
111
 
112
  # Print resized image shape for debugging
113
  print(f"Resized image shape: {original_image.shape}")
114
 
115
+ # Prepare the image for processing - EXACTLY as in demo_i2i.py
116
  padding_feed = [x for x in original_image.movedim(1, -1).float().cpu().numpy()]
117
  list_of_np_rgb_padded = [pad_rgb(x) for x in padding_feed]
118
  rgb_padded_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgb_padded, axis=0)).float().movedim(-1, 1).to(device)
 
129
  alpha = torch.ones((original_image_feed.shape[0], 1, height, width), device=original_image_feed.device)
130
  original_image_feed = torch.cat([original_image_feed, alpha], dim=1)
131
 
132
+ # Apply alpha to RGB channels - EXACTLY as in demo_i2i.py
133
+ original_image_rgb = original_image_feed[:, :3, :, :] * original_image_feed[:, 3:4, :, :]
 
 
 
 
 
 
134
 
135
  # Print shape information for debugging
136
  print(f"RGB tensor shape: {original_image_feed[:, :3, :, :].shape}")
137
+ print(f"Alpha channel shape: {original_image_feed[:, 3:4, :, :].shape}")
138
  print(f"RGB*alpha tensor shape: {original_image_rgb.shape}")
139
 
140
  # Move tensors to device
141
  original_image_feed = original_image_feed.to(device)
142
  original_image_rgb = original_image_rgb.to(device)
143
+ rgb_padded_bchw_01 = rgb_padded_bchw_01.to(device)
144
 
145
  # Verify tensor shapes before encoding
146
  print(f"Before encoding - original_image_feed: {original_image_feed.shape}")
 
180
  raise
181
 
182
  # Free up memory
183
+ del initial_latent, original_image
184
  torch.cuda.empty_cache()
185
 
186
  # Process the latents
 
219
  del latents
220
  torch.cuda.empty_cache()
221
 
222
+ # Convert to image - EXACTLY as in demo_i2i.py
223
  x = x.clamp(0, 1)
224
+ x = x.permute(0, 2, 3, 1)
225
+ img = Image.fromarray((x*255).float().cpu().numpy().astype(np.uint8)[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  # Clean up
228
+ del original_x, x
229
  torch.cuda.empty_cache()
230
  gc.collect()
231