Vijish commited on
Commit
d02e046
·
verified ·
1 Parent(s): 068146a

Update utils/i2i.py

Browse files
Files changed (1) hide show
  1. utils/i2i.py +2 -62
utils/i2i.py CHANGED
@@ -20,46 +20,6 @@ def seed_everything(seed: int) -> torch.Generator:
20
  generator.manual_seed(seed)
21
  return generator
22
 
23
- # Function to properly handle input image transparency
24
- def prepare_transparent_image(input_image):
25
- """
26
- Ensures the input image has proper transparency.
27
- Converts the image to RGBA if it's not already and enhances partial transparency.
28
- """
29
- # Convert to RGBA if not already
30
- if input_image.mode != 'RGBA':
31
- input_image = input_image.convert('RGBA')
32
-
33
- # Get image data as numpy array
34
- img_array = np.array(input_image)
35
-
36
- # Print alpha channel stats for debugging
37
- alpha = img_array[:, :, 3]
38
- print(f"Input alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, mean: {alpha.mean()}")
39
-
40
- # Check if the image has partial transparency (alpha values between 1-254)
41
- if alpha.min() < 255 and np.any((alpha > 0) & (alpha < 255)):
42
- print("Detected partial transparency, enhancing alpha channel...")
43
-
44
- # For partially transparent pixels (alpha < 200), make them fully transparent
45
- partial_mask = alpha < 200
46
- img_array[partial_mask, 3] = 0
47
-
48
- # For mostly opaque pixels (alpha >= 200), make them fully opaque
49
- opaque_mask = alpha >= 200
50
- img_array[opaque_mask, 3] = 255
51
-
52
- # Create new PIL image with enhanced alpha
53
- enhanced_image = Image.fromarray(img_array, 'RGBA')
54
-
55
- # Print updated alpha stats
56
- enhanced_alpha = np.array(enhanced_image)[:, :, 3]
57
- print(f"Enhanced alpha channel stats - min: {enhanced_alpha.min()}, max: {enhanced_alpha.max()}, mean: {enhanced_alpha.mean()}")
58
-
59
- return enhanced_image
60
-
61
- return input_image
62
-
63
  # Initialize the pipeline
64
  i2i_pipe = FluxImg2ImgPipeline.from_pretrained(
65
  "black-forest-labs/FLUX.1-dev",
@@ -110,9 +70,6 @@ def i2i_gen(
110
  gc.collect()
111
 
112
  try:
113
- # Prepare the input image for proper transparency handling
114
- input_image = prepare_transparent_image(input_image)
115
-
116
  # Process the input image
117
  original_image = (transforms.ToTensor()(input_image)).unsqueeze(0)
118
 
@@ -265,24 +222,7 @@ def i2i_gen(
265
  # Convert to image - EXACTLY as in demo_i2i.py
266
  x = x.clamp(0, 1)
267
  x = x.permute(0, 2, 3, 1)
268
- img_array = (x*255).float().cpu().numpy().astype(np.uint8)[0]
269
-
270
- # Ensure the output image has proper transparency
271
- if img_array.shape[2] == 4:
272
- # Print alpha channel stats for debugging
273
- alpha = img_array[:, :, 3]
274
- print(f"Output alpha channel stats - min: {alpha.min()}, max: {alpha.max()}, mean: {alpha.mean()}")
275
-
276
- # Make sure partially transparent pixels are handled properly
277
- # For partially transparent pixels (alpha < 200), make them fully transparent
278
- partial_mask = alpha < 200
279
- img_array[partial_mask, 3] = 0
280
-
281
- # For mostly opaque pixels (alpha >= 200), make them fully opaque
282
- opaque_mask = alpha >= 200
283
- img_array[opaque_mask, 3] = 255
284
-
285
- img = Image.fromarray(img_array, 'RGBA' if img_array.shape[2] == 4 else 'RGB')
286
 
287
  # Clean up
288
  del original_x, x
@@ -298,4 +238,4 @@ def i2i_gen(
298
  traceback.print_exc()
299
  torch.cuda.empty_cache()
300
  gc.collect()
301
- return None
 
20
  generator.manual_seed(seed)
21
  return generator
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Initialize the pipeline
24
  i2i_pipe = FluxImg2ImgPipeline.from_pretrained(
25
  "black-forest-labs/FLUX.1-dev",
 
70
  gc.collect()
71
 
72
  try:
 
 
 
73
  # Process the input image
74
  original_image = (transforms.ToTensor()(input_image)).unsqueeze(0)
75
 
 
222
  # Convert to image - EXACTLY as in demo_i2i.py
223
  x = x.clamp(0, 1)
224
  x = x.permute(0, 2, 3, 1)
225
+ img = Image.fromarray((x*255).float().cpu().numpy().astype(np.uint8)[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  # Clean up
228
  del original_x, x
 
238
  traceback.print_exc()
239
  torch.cuda.empty_cache()
240
  gc.collect()
241
+ return None