Vijish commited on
Commit
068146a
·
verified ·
1 Parent(s): a1fbcab

Update utils/i2i.py

Browse files
Files changed (1) hide show
  1. utils/i2i.py +62 -2
utils/i2i.py CHANGED
@@ -20,6 +20,46 @@ def seed_everything(seed: int) -> torch.Generator:
20
  generator.manual_seed(seed)
21
  return generator
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Initialize the pipeline
24
  i2i_pipe = FluxImg2ImgPipeline.from_pretrained(
25
  "black-forest-labs/FLUX.1-dev",
@@ -70,6 +110,9 @@ def i2i_gen(
70
  gc.collect()
71
 
72
  try:
 
 
 
73
  # Process the input image
74
  original_image = (transforms.ToTensor()(input_image)).unsqueeze(0)
75
 
@@ -222,7 +265,24 @@ def i2i_gen(
222
  # Convert to image - EXACTLY as in demo_i2i.py
223
  x = x.clamp(0, 1)
224
  x = x.permute(0, 2, 3, 1)
225
- img = Image.fromarray((x*255).float().cpu().numpy().astype(np.uint8)[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  # Clean up
228
  del original_x, x
@@ -238,4 +298,4 @@ def i2i_gen(
238
  traceback.print_exc()
239
  torch.cuda.empty_cache()
240
  gc.collect()
241
- return None
 
20
  generator.manual_seed(seed)
21
  return generator
22
 
23
+ # Function to properly handle input image transparency
24
+ def prepare_transparent_image(input_image):
25
+ """
26
+ Ensures the input image has proper transparency.
27
+ Converts the image to RGBA if it's not already and enhances partial transparency.
28
+ """
29
+ # Convert to RGBA if not already
30
+ if input_image.mode != 'RGBA':
31
+ input_image = input_image.convert('RGBA')
32
+
33
+ # Get image data as numpy array
34
+ img_array = np.array(input_image)
35
+
36
+ # Print alpha channel stats for debugging
37
+ alpha = img_array[:, :, 3]
38
+ print(f"Input alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, mean: {alpha.mean()}")
39
+
40
+ # Check if the image has partial transparency (alpha values between 1-254)
41
+ if alpha.min() < 255 and np.any((alpha > 0) & (alpha < 255)):
42
+ print("Detected partial transparency, enhancing alpha channel...")
43
+
44
+ # For partially transparent pixels (alpha < 200), make them fully transparent
45
+ partial_mask = alpha < 200
46
+ img_array[partial_mask, 3] = 0
47
+
48
+ # For mostly opaque pixels (alpha >= 200), make them fully opaque
49
+ opaque_mask = alpha >= 200
50
+ img_array[opaque_mask, 3] = 255
51
+
52
+ # Create new PIL image with enhanced alpha
53
+ enhanced_image = Image.fromarray(img_array, 'RGBA')
54
+
55
+ # Print updated alpha stats
56
+ enhanced_alpha = np.array(enhanced_image)[:, :, 3]
57
+ print(f"Enhanced alpha channel stats - min: {enhanced_alpha.min()}, max: {enhanced_alpha.max()}, mean: {enhanced_alpha.mean()}")
58
+
59
+ return enhanced_image
60
+
61
+ return input_image
62
+
63
  # Initialize the pipeline
64
  i2i_pipe = FluxImg2ImgPipeline.from_pretrained(
65
  "black-forest-labs/FLUX.1-dev",
 
110
  gc.collect()
111
 
112
  try:
113
+ # Prepare the input image for proper transparency handling
114
+ input_image = prepare_transparent_image(input_image)
115
+
116
  # Process the input image
117
  original_image = (transforms.ToTensor()(input_image)).unsqueeze(0)
118
 
 
265
  # Convert to image - EXACTLY as in demo_i2i.py
266
  x = x.clamp(0, 1)
267
  x = x.permute(0, 2, 3, 1)
268
+ img_array = (x*255).float().cpu().numpy().astype(np.uint8)[0]
269
+
270
+ # Ensure the output image has proper transparency
271
+ if img_array.shape[2] == 4:
272
+ # Print alpha channel stats for debugging
273
+ alpha = img_array[:, :, 3]
274
+ print(f"Output alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, mean: {alpha.mean()}")
275
+
276
+ # Make sure partially transparent pixels are handled properly
277
+ # For partially transparent pixels (alpha < 200), make them fully transparent
278
+ partial_mask = alpha < 200
279
+ img_array[partial_mask, 3] = 0
280
+
281
+ # For mostly opaque pixels (alpha >= 200), make them fully opaque
282
+ opaque_mask = alpha >= 200
283
+ img_array[opaque_mask, 3] = 255
284
+
285
+ img = Image.fromarray(img_array, 'RGBA' if img_array.shape[2] == 4 else 'RGB')
286
 
287
  # Clean up
288
  del original_x, x
 
298
  traceback.print_exc()
299
  torch.cuda.empty_cache()
300
  gc.collect()
301
+ return None