Spaces:
Runtime error
Runtime error
Update utils/i2i.py
Browse files: utils/i2i.py +11 -54
utils/i2i.py
CHANGED
@@ -36,7 +36,7 @@ trans_vae.load_state_dict(torch.load("./models/TransparentVAE.pth"), strict=Fals
|
|
36 |
trans_vae.to(device)
|
37 |
|
38 |
# Custom function to safely decode latents
|
39 |
-
def safe_decode(trans_vae, latents
|
40 |
try:
|
41 |
# Standard decoding approach
|
42 |
original_x, x = trans_vae.decode(latents)
|
@@ -73,12 +73,6 @@ def i2i_gen(
|
|
73 |
# Process the input image
|
74 |
original_image = (transforms.ToTensor()(input_image)).unsqueeze(0)
|
75 |
|
76 |
-
# Store the original alpha channel for later use
|
77 |
-
original_alpha = None
|
78 |
-
if original_image.shape[1] == 4:
|
79 |
-
original_alpha = original_image[:, 3:4, :, :].clone()
|
80 |
-
print("Original alpha channel detected and saved")
|
81 |
-
|
82 |
# Print original image shape for debugging
|
83 |
print(f"Original image shape: {original_image.shape}")
|
84 |
|
@@ -114,15 +108,11 @@ def i2i_gen(
|
|
114 |
if height != original_image.shape[2] or width != original_image.shape[3]:
|
115 |
print(f"Resizing image from {original_image.shape[2]}x{original_image.shape[3]} to {height}x{width}")
|
116 |
original_image = transforms.functional.resize(original_image, (height, width))
|
117 |
-
|
118 |
-
# Also resize the alpha channel if it exists
|
119 |
-
if original_alpha is not None:
|
120 |
-
original_alpha = transforms.functional.resize(original_alpha, (height, width))
|
121 |
|
122 |
# Print resized image shape for debugging
|
123 |
print(f"Resized image shape: {original_image.shape}")
|
124 |
|
125 |
-
# Prepare the image for processing
|
126 |
padding_feed = [x for x in original_image.movedim(1, -1).float().cpu().numpy()]
|
127 |
list_of_np_rgb_padded = [pad_rgb(x) for x in padding_feed]
|
128 |
rgb_padded_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgb_padded, axis=0)).float().movedim(-1, 1).to(device)
|
@@ -139,22 +129,18 @@ def i2i_gen(
|
|
139 |
alpha = torch.ones((original_image_feed.shape[0], 1, height, width), device=original_image_feed.device)
|
140 |
original_image_feed = torch.cat([original_image_feed, alpha], dim=1)
|
141 |
|
142 |
-
#
|
143 |
-
|
144 |
-
|
145 |
-
# Make sure alpha channel is correctly extracted
|
146 |
-
alpha_channel = original_image_feed[:, 3:4, :, :].clone()
|
147 |
-
|
148 |
-
# Apply alpha to RGB channels
|
149 |
-
original_image_rgb = original_image_feed[:, :3, :, :].clone() * alpha_channel
|
150 |
|
151 |
# Print shape information for debugging
|
152 |
print(f"RGB tensor shape: {original_image_feed[:, :3, :, :].shape}")
|
|
|
153 |
print(f"RGB*alpha tensor shape: {original_image_rgb.shape}")
|
154 |
|
155 |
# Move tensors to device
|
156 |
original_image_feed = original_image_feed.to(device)
|
157 |
original_image_rgb = original_image_rgb.to(device)
|
|
|
158 |
|
159 |
# Verify tensor shapes before encoding
|
160 |
print(f"Before encoding - original_image_feed: {original_image_feed.shape}")
|
@@ -194,7 +180,7 @@ def i2i_gen(
|
|
194 |
raise
|
195 |
|
196 |
# Free up memory
|
197 |
-
del initial_latent
|
198 |
torch.cuda.empty_cache()
|
199 |
|
200 |
# Process the latents
|
@@ -233,42 +219,13 @@ def i2i_gen(
|
|
233 |
del latents
|
234 |
torch.cuda.empty_cache()
|
235 |
|
236 |
-
# Convert to image
|
237 |
x = x.clamp(0, 1)
|
238 |
-
|
239 |
-
|
240 |
-
rgba_tensor = torch.zeros((x.shape[0], 4, x.shape[2], x.shape[3]), device=x.device)
|
241 |
-
|
242 |
-
# Copy the RGB channels
|
243 |
-
rgba_tensor[:, :3, :, :] = x
|
244 |
-
|
245 |
-
# Use the original alpha channel if available, otherwise use the alpha from the model
|
246 |
-
if original_alpha is not None:
|
247 |
-
print("Using original alpha channel for output")
|
248 |
-
# Resize alpha to match output dimensions if needed
|
249 |
-
if original_alpha.shape[2:] != x.shape[2:]:
|
250 |
-
original_alpha = torch.nn.functional.interpolate(
|
251 |
-
original_alpha,
|
252 |
-
size=(x.shape[2], x.shape[3]),
|
253 |
-
mode='bilinear',
|
254 |
-
align_corners=False
|
255 |
-
)
|
256 |
-
rgba_tensor[:, 3:4, :, :] = original_alpha.to(x.device)
|
257 |
-
else:
|
258 |
-
# If no original alpha, create a binary mask based on pixel intensity
|
259 |
-
# Pixels that are nearly black (all channels < 0.05) will be transparent
|
260 |
-
print("Creating alpha mask based on pixel intensity")
|
261 |
-
rgb_sum = x.sum(dim=1, keepdim=True) / 3.0
|
262 |
-
mask = (rgb_sum > 0.05).float()
|
263 |
-
rgba_tensor[:, 3:4, :, :] = mask
|
264 |
-
|
265 |
-
# Convert to PIL image with transparency
|
266 |
-
rgba_tensor = rgba_tensor.permute(0, 2, 3, 1)
|
267 |
-
rgba_array = (rgba_tensor[0] * 255).float().cpu().numpy().astype(np.uint8)
|
268 |
-
img = Image.fromarray(rgba_array, mode='RGBA')
|
269 |
|
270 |
# Clean up
|
271 |
-
del original_x, x
|
272 |
torch.cuda.empty_cache()
|
273 |
gc.collect()
|
274 |
|
|
|
36 |
trans_vae.to(device)
|
37 |
|
38 |
# Custom function to safely decode latents
|
39 |
+
def safe_decode(trans_vae, latents):
|
40 |
try:
|
41 |
# Standard decoding approach
|
42 |
original_x, x = trans_vae.decode(latents)
|
|
|
73 |
# Process the input image
|
74 |
original_image = (transforms.ToTensor()(input_image)).unsqueeze(0)
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
# Print original image shape for debugging
|
77 |
print(f"Original image shape: {original_image.shape}")
|
78 |
|
|
|
108 |
if height != original_image.shape[2] or width != original_image.shape[3]:
|
109 |
print(f"Resizing image from {original_image.shape[2]}x{original_image.shape[3]} to {height}x{width}")
|
110 |
original_image = transforms.functional.resize(original_image, (height, width))
|
|
|
|
|
|
|
|
|
111 |
|
112 |
# Print resized image shape for debugging
|
113 |
print(f"Resized image shape: {original_image.shape}")
|
114 |
|
115 |
+
# Prepare the image for processing - EXACTLY as in demo_i2i.py
|
116 |
padding_feed = [x for x in original_image.movedim(1, -1).float().cpu().numpy()]
|
117 |
list_of_np_rgb_padded = [pad_rgb(x) for x in padding_feed]
|
118 |
rgb_padded_bchw_01 = torch.from_numpy(np.stack(list_of_np_rgb_padded, axis=0)).float().movedim(-1, 1).to(device)
|
|
|
129 |
alpha = torch.ones((original_image_feed.shape[0], 1, height, width), device=original_image_feed.device)
|
130 |
original_image_feed = torch.cat([original_image_feed, alpha], dim=1)
|
131 |
|
132 |
+
# Apply alpha to RGB channels - EXACTLY as in demo_i2i.py
|
133 |
+
original_image_rgb = original_image_feed[:, :3, :, :] * original_image_feed[:, 3:4, :, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
# Print shape information for debugging
|
136 |
print(f"RGB tensor shape: {original_image_feed[:, :3, :, :].shape}")
|
137 |
+
print(f"Alpha channel shape: {original_image_feed[:, 3:4, :, :].shape}")
|
138 |
print(f"RGB*alpha tensor shape: {original_image_rgb.shape}")
|
139 |
|
140 |
# Move tensors to device
|
141 |
original_image_feed = original_image_feed.to(device)
|
142 |
original_image_rgb = original_image_rgb.to(device)
|
143 |
+
rgb_padded_bchw_01 = rgb_padded_bchw_01.to(device)
|
144 |
|
145 |
# Verify tensor shapes before encoding
|
146 |
print(f"Before encoding - original_image_feed: {original_image_feed.shape}")
|
|
|
180 |
raise
|
181 |
|
182 |
# Free up memory
|
183 |
+
del initial_latent, original_image
|
184 |
torch.cuda.empty_cache()
|
185 |
|
186 |
# Process the latents
|
|
|
219 |
del latents
|
220 |
torch.cuda.empty_cache()
|
221 |
|
222 |
+
# Convert to image - EXACTLY as in demo_i2i.py
|
223 |
x = x.clamp(0, 1)
|
224 |
+
x = x.permute(0, 2, 3, 1)
|
225 |
+
img = Image.fromarray((x*255).float().cpu().numpy().astype(np.uint8)[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
|
227 |
# Clean up
|
228 |
+
del original_x, x
|
229 |
torch.cuda.empty_cache()
|
230 |
gc.collect()
|
231 |
|