Spaces:
Sleeping
Sleeping
Update utils/i2i.py
Browse files- utils/i2i.py +2 -62
utils/i2i.py
CHANGED
@@ -20,46 +20,6 @@ def seed_everything(seed: int) -> torch.Generator:
|
|
20 |
generator.manual_seed(seed)
|
21 |
return generator
|
22 |
|
23 |
-
# Function to properly handle input image transparency
|
24 |
-
def prepare_transparent_image(input_image):
|
25 |
-
"""
|
26 |
-
Ensures the input image has proper transparency.
|
27 |
-
Converts the image to RGBA if it's not already and enhances partial transparency.
|
28 |
-
"""
|
29 |
-
# Convert to RGBA if not already
|
30 |
-
if input_image.mode != 'RGBA':
|
31 |
-
input_image = input_image.convert('RGBA')
|
32 |
-
|
33 |
-
# Get image data as numpy array
|
34 |
-
img_array = np.array(input_image)
|
35 |
-
|
36 |
-
# Print alpha channel stats for debugging
|
37 |
-
alpha = img_array[:, :, 3]
|
38 |
-
print(f"Input alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, mean: {alpha.mean()}")
|
39 |
-
|
40 |
-
# Check if the image has partial transparency (alpha values between 1-254)
|
41 |
-
if alpha.min() < 255 and np.any((alpha > 0) & (alpha < 255)):
|
42 |
-
print("Detected partial transparency, enhancing alpha channel...")
|
43 |
-
|
44 |
-
# For partially transparent pixels (alpha < 200), make them fully transparent
|
45 |
-
partial_mask = alpha < 200
|
46 |
-
img_array[partial_mask, 3] = 0
|
47 |
-
|
48 |
-
# For mostly opaque pixels (alpha >= 200), make them fully opaque
|
49 |
-
opaque_mask = alpha >= 200
|
50 |
-
img_array[opaque_mask, 3] = 255
|
51 |
-
|
52 |
-
# Create new PIL image with enhanced alpha
|
53 |
-
enhanced_image = Image.fromarray(img_array, 'RGBA')
|
54 |
-
|
55 |
-
# Print updated alpha stats
|
56 |
-
enhanced_alpha = np.array(enhanced_image)[:, :, 3]
|
57 |
-
print(f"Enhanced alpha channel stats - min: {enhanced_alpha.min()}, max: {enhanced_alpha.max()}, mean: {enhanced_alpha.mean()}")
|
58 |
-
|
59 |
-
return enhanced_image
|
60 |
-
|
61 |
-
return input_image
|
62 |
-
|
63 |
# Initialize the pipeline
|
64 |
i2i_pipe = FluxImg2ImgPipeline.from_pretrained(
|
65 |
"black-forest-labs/FLUX.1-dev",
|
@@ -110,9 +70,6 @@ def i2i_gen(
|
|
110 |
gc.collect()
|
111 |
|
112 |
try:
|
113 |
-
# Prepare the input image for proper transparency handling
|
114 |
-
input_image = prepare_transparent_image(input_image)
|
115 |
-
|
116 |
# Process the input image
|
117 |
original_image = (transforms.ToTensor()(input_image)).unsqueeze(0)
|
118 |
|
@@ -265,24 +222,7 @@ def i2i_gen(
|
|
265 |
# Convert to image - EXACTLY as in demo_i2i.py
|
266 |
x = x.clamp(0, 1)
|
267 |
x = x.permute(0, 2, 3, 1)
|
268 |
-
|
269 |
-
|
270 |
-
# Ensure the output image has proper transparency
|
271 |
-
if img_array.shape[2] == 4:
|
272 |
-
# Print alpha channel stats for debugging
|
273 |
-
alpha = img_array[:, :, 3]
|
274 |
-
print(f"Output alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, mean: {alpha.mean()}")
|
275 |
-
|
276 |
-
# Make sure partially transparent pixels are handled properly
|
277 |
-
# For partially transparent pixels (alpha < 200), make them fully transparent
|
278 |
-
partial_mask = alpha < 200
|
279 |
-
img_array[partial_mask, 3] = 0
|
280 |
-
|
281 |
-
# For mostly opaque pixels (alpha >= 200), make them fully opaque
|
282 |
-
opaque_mask = alpha >= 200
|
283 |
-
img_array[opaque_mask, 3] = 255
|
284 |
-
|
285 |
-
img = Image.fromarray(img_array, 'RGBA' if img_array.shape[2] == 4 else 'RGB')
|
286 |
|
287 |
# Clean up
|
288 |
del original_x, x
|
@@ -298,4 +238,4 @@ def i2i_gen(
|
|
298 |
traceback.print_exc()
|
299 |
torch.cuda.empty_cache()
|
300 |
gc.collect()
|
301 |
-
return None
|
|
|
20 |
generator.manual_seed(seed)
|
21 |
return generator
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# Initialize the pipeline
|
24 |
i2i_pipe = FluxImg2ImgPipeline.from_pretrained(
|
25 |
"black-forest-labs/FLUX.1-dev",
|
|
|
70 |
gc.collect()
|
71 |
|
72 |
try:
|
|
|
|
|
|
|
73 |
# Process the input image
|
74 |
original_image = (transforms.ToTensor()(input_image)).unsqueeze(0)
|
75 |
|
|
|
222 |
# Convert to image - EXACTLY as in demo_i2i.py
|
223 |
x = x.clamp(0, 1)
|
224 |
x = x.permute(0, 2, 3, 1)
|
225 |
+
img = Image.fromarray((x*255).float().cpu().numpy().astype(np.uint8)[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
|
227 |
# Clean up
|
228 |
del original_x, x
|
|
|
238 |
traceback.print_exc()
|
239 |
torch.cuda.empty_cache()
|
240 |
gc.collect()
|
241 |
+
return None
|