6Morpheus6 committed
Commit 63a09ce · verified · 1 Parent(s): 1700ed2

Multi image support

Files changed (1)
  1. app.py +134 -95
app.py CHANGED
@@ -6,16 +6,16 @@ import torch
 import devicetorch
 import gradio as gr
 import numpy as np
-# import spaces
 from PIL import Image
-
+from dfloat11 import DFloat11Model
+#from kontext_pipeline import FluxKontextPipeline
 from diffusers import FluxKontextPipeline
 from diffusers.utils import load_image
-from dfloat11 import DFloat11Model
 
+# Load Kontext model
 MAX_SEED = np.iinfo(np.int32).max
 
-pipe = FluxKontextPipeline.from_pretrained("fuliucansheng/FLUX.1-Kontext-dev-diffusers", torch_dtype=torch.bfloat16)
+pipe = FluxKontextPipeline.from_pretrained("fuliucansheng/FLUX.1-Kontext-dev-diffusers", torch_dtype=torch.bfloat16).to("cuda")
 DFloat11Model.from_pretrained(
     "DFloat11/FLUX.1-Kontext-dev-DF11",
     device="cpu",
@@ -23,69 +23,116 @@ DFloat11Model.from_pretrained(
 )
 pipe.enable_model_cpu_offload()
 
-# @spaces.GPU
-def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, steps=28, progress=gr.Progress(track_tqdm=True)):
+def concatenate_images(images, direction="horizontal"):
     """
-    Perform image editing using the FLUX.1 Kontext pipeline.
-
-    This function takes an input image and a text prompt to generate a modified version
-    of the image based on the provided instructions. It uses the FLUX.1 Kontext model
-    for contextual image editing tasks.
+    Concatenate multiple PIL images either horizontally or vertically.
 
     Args:
-        input_image (PIL.Image.Image): The input image to be edited. Will be converted
-            to RGB format if not already in that format.
-        prompt (str): Text description of the desired edit to apply to the image.
-            Examples: "Remove glasses", "Add a hat", "Change background to beach".
-        seed (int, optional): Random seed for reproducible generation. Defaults to 42.
-            Must be between 0 and MAX_SEED (2^31 - 1).
-        randomize_seed (bool, optional): If True, generates a random seed instead of
-            using the provided seed value. Defaults to False.
-        guidance_scale (float, optional): Controls how closely the model follows the
-            prompt. Higher values mean stronger adherence to the prompt but may reduce
-            image quality. Range: 1.0-10.0. Defaults to 2.5.
-        steps (int, optional): Controls how many steps to run the diffusion model for.
-            Range: 1-30. Defaults to 28.
-        progress (gr.Progress, optional): Gradio progress tracker for monitoring
-            generation progress. Defaults to gr.Progress(track_tqdm=True).
+        images: List of PIL Images
+        direction: "horizontal" or "vertical"
 
     Returns:
-        tuple: A 3-tuple containing:
-            - PIL.Image.Image: The generated/edited image
-            - int: The seed value used for generation (useful when randomize_seed=True)
-            - gr.update: Gradio update object to make the reuse button visible
-
-    Example:
-        >>> edited_image, used_seed, button_update = infer(
-        ...     input_image=my_image,
-        ...     prompt="Add sunglasses",
-        ...     seed=123,
-        ...     randomize_seed=False,
-        ...     guidance_scale=2.5
-        ... )
+        PIL Image: Concatenated image
     """
+    if not images:
+        return None
+
+    # Filter out None images
+    valid_images = [img for img in images if img is not None]
+
+    if not valid_images:
+        return None
+
+    if len(valid_images) == 1:
+        return valid_images[0].convert("RGB")
+
+    # Convert all images to RGB
+    valid_images = [img.convert("RGB") for img in valid_images]
+
+    if direction == "horizontal":
+        # Calculate total width and max height
+        total_width = sum(img.width for img in valid_images)
+        max_height = max(img.height for img in valid_images)
+
+        # Create new image
+        concatenated = Image.new('RGB', (total_width, max_height), (255, 255, 255))
+
+        # Paste images
+        x_offset = 0
+        for img in valid_images:
+            # Center image vertically if heights differ
+            y_offset = (max_height - img.height) // 2
+            concatenated.paste(img, (x_offset, y_offset))
+            x_offset += img.width
+
+    else:  # vertical
+        # Calculate max width and total height
+        max_width = max(img.width for img in valid_images)
+        total_height = sum(img.height for img in valid_images)
+
+        # Create new image
+        concatenated = Image.new('RGB', (max_width, total_height), (255, 255, 255))
+
+        # Paste images
+        y_offset = 0
+        for img in valid_images:
+            # Center image horizontally if widths differ
+            x_offset = (max_width - img.width) // 2
+            concatenated.paste(img, (x_offset, y_offset))
+            y_offset += img.height
+
+    return concatenated
+
+def infer(input_images, prompt, seed=42, randomize_seed=False, guidance_scale=4.0, steps=25, progress=gr.Progress(track_tqdm=True)):
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
-    if input_image:
-        input_image = input_image.convert("RGB")
-        image = pipe(
-            image=input_image,
-            prompt=prompt,
-            guidance_scale=guidance_scale,
-            width = input_image.size[0],
-            height = input_image.size[1],
-            num_inference_steps=steps,
-            generator=torch.Generator().manual_seed(seed),
-        ).images[0]
-    else:
-        image = pipe(
-            prompt=prompt,
-            guidance_scale=guidance_scale,
-            num_inference_steps=steps,
-            generator=torch.Generator().manual_seed(seed),
-        ).images[0]
+    # Handle input_images - it could be a single image or a list of images
+    if input_images is None:
+        raise gr.Error("Please upload at least one image.")
+
+    # If it's a single image (not a list), convert to list
+    if not isinstance(input_images, list):
+        input_images = [input_images]
+
+    # Filter out None images
+    valid_images = [img[0] for img in input_images if img is not None]
+
+    if not valid_images:
+        raise gr.Error("Please upload at least one valid image.")
+
+    # Concatenate images horizontally
+    concatenated_image = concatenate_images(valid_images, "horizontal")
+
+    if concatenated_image is None:
+        raise gr.Error("Failed to process the input images.")
+
+    # original_width, original_height = concatenated_image.size
+
+    # if original_width >= original_height:
+    #     new_width = 1024
+    #     new_height = int(original_height * (new_width / original_width))
+    #     new_height = round(new_height / 64) * 64
+    # else:
+    #     new_height = 1024
+    #     new_width = int(original_width * (new_height / original_height))
+    #     new_width = round(new_width / 64) * 64
+
+    #concatenated_image_resized = concatenated_image.resize((new_width, new_height), Image.LANCZOS)
 
+    final_prompt = f"From the provided reference images, create a unified, cohesive image such that {prompt}. Maintain the identity and characteristics of each subject while adjusting their proportions, scale, and positioning to create a harmonious, naturally balanced composition. Blend and integrate all elements seamlessly with consistent lighting, perspective, and style. The final result should look like a single naturally captured scene where all subjects are properly sized and positioned relative to each other, not assembled from multiple sources."
+
+    image = pipe(
+        image=concatenated_image,
+        prompt=final_prompt,
+        guidance_scale=guidance_scale,
+        width=concatenated_image.size[0],
+        height=concatenated_image.size[1],
+        num_inference_steps=steps,
+        generator=torch.Generator().manual_seed(seed),
+    ).images[0]
+
     gradio_temp_dir = os.environ.get('GRADIO_TEMP_DIR', tempfile.gettempdir())
     temp_file_path = os.path.join(gradio_temp_dir, "image.png")
     image.save(temp_file_path, format="PNG")
@@ -94,14 +141,7 @@ def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5
     gc.collect()
     devicetorch.empty_cache(torch)
 
-    return image, temp_file_path, seed, gr.Button(visible=True)
-
-# @spaces.GPU
-def infer_example(input_image, prompt):
-    image, temp_file_path, seed, _ = infer(input_image, prompt)
-    gc.collect()
-    devicetorch.empty_cache(torch)
-    return image,temp_file_path, seed
+    return image, seed, gr.update(visible=True)
 
 css="""
 #col-container {
@@ -114,7 +154,6 @@ css="""
 #row {
     min-height: 40vh; !Important
 }
-
 #row-height {
     height: 65px !important
 }
@@ -123,17 +162,26 @@ css="""
 with gr.Blocks(css=css) as demo:
 
     with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""# FLUX.1 Kontext [dev]
-Image editing and manipulation model guidance-distilled from FLUX.1 Kontext [pro], [[blog]](https://bfl.ai/announcements/flux-1-kontext-dev) [[model]](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev)
+        gr.Markdown(f"""# FLUX.1 Kontext [dev] - Multi-Image
+Flux Kontext with multiple image input support - compose a new image with elements from multiple images using Kontext [dev]
         """)
         with gr.Row(equal_height=True):
            with gr.Column():
-                input_image = gr.Image(label="Upload the image for editing", type="pil", elem_classes="input-image", elem_id="row")
-
+                input_images = gr.Gallery(
+                    label="Upload image(s) for editing",
+                    show_label=True,
+                    elem_id="gallery_input",
+                    columns=3,
+                    rows=2,
+                    object_fit="contain",
+                    height="auto",
+                    file_types=['image'],
+                    type='pil'
+                )
+
            with gr.Column():
                 result = gr.Image(label="Result", show_label=False, interactive=False, elem_classes="input-image", elem_id="row")
-                reuse_button = gr.Button("Reuse this image", visible=False)
-
        with gr.Row(equal_height=True):
            with gr.Column():
                prompt = gr.Text(
@@ -145,14 +193,14 @@ Image editing and manipulation model guidance-distilled from FLUX.1 Kontext [pro
                    container=True,
                    scale=1
                )
-
+
            with gr.Column():
-                download_image = gr.File(label="Download Image", elem_id="row-height", scale=0)
+                download_image = gr.File(label="Download Image", elem_id="row-height", interactive=False, scale=0)
                run_button = gr.Button("Run", scale=1)
 
    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
-
+
            seed = gr.Slider(
                label="Seed",
                minimum=0,
@@ -168,39 +216,30 @@ Image editing and manipulation model guidance-distilled from FLUX.1 Kontext [pro
                minimum=1,
                maximum=10,
                step=0.1,
-                value=2.5,
+                value=4.0,
            )
 
            steps = gr.Slider(
                label="Steps",
                minimum=1,
                maximum=40,
-                value=28,
+                value=25,
                step=1
-            )
-
-    examples = gr.Examples(
-        examples=[
-            ["flowers.png", "turn the flowers into sunflowers"],
-            ["monster.png", "make this monster ride a skateboard on the beach"],
-            ["cat.png", "make this cat happy"]
-        ],
-        inputs=[input_image, prompt],
-        outputs=[result, download_image, seed],
-        fn=infer_example,
-        cache_examples=False
-    )
-
+            )
+
+    reuse_button = gr.Button("Reuse this image", visible=False)
+
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn = infer,
-        inputs = [input_image, prompt, seed, randomize_seed, guidance_scale, steps],
-        outputs = [result, download_image, seed, reuse_button]
+        inputs = [input_images, prompt, seed, randomize_seed, guidance_scale, steps],
+        outputs = [result, seed, reuse_button]
    )
+
    reuse_button.click(
-        fn = lambda image: image,
+        fn = lambda image: [image] if image is not None else [],  # Convert single image to list for gallery
        inputs = [result],
-        outputs = [input_image]
+        outputs = [input_images]
    )
 
 demo.launch(mcp_server=True)
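
For readers following the diff, the new concatenate_images helper reduces to plain PIL pasting. Below is a minimal standalone sketch of the horizontal case, assuming two hypothetical local files subject_a.png and subject_b.png; it mirrors the helper's behavior of summing widths, taking the tallest height, and centering shorter images on a white canvas.

from PIL import Image

# Hypothetical inputs for illustration; any two local images work.
left = Image.open("subject_a.png")
right = Image.open("subject_b.png")

# White canvas: total width is the sum of widths, height is the tallest image.
canvas = Image.new("RGB", (left.width + right.width, max(left.height, right.height)), (255, 255, 255))
# Paste each image, centering the shorter one vertically.
canvas.paste(left.convert("RGB"), (0, (canvas.height - left.height) // 2))
canvas.paste(right.convert("RGB"), (left.width, (canvas.height - right.height) // 2))
canvas.save("concatenated.png")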
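The img[0] indexing in infer() and the list-wrapping lambda on the reuse button both follow from how gr.Gallery passes values: with type='pil' the input appears to arrive as a list of (image, caption) pairs, and a plain list of images is accepted when writing back. A toy illustration, with hypothetical values standing in for what Gradio would supply:

from PIL import Image

img_a = Image.new("RGB", (64, 64))
img_b = Image.new("RGB", (64, 64))

# What infer() appears to receive from the gallery (captions may be None).
gallery_value = [(img_a, None), (img_b, "second reference")]
valid_images = [item[0] for item in gallery_value if item is not None]
assert all(isinstance(im, Image.Image) for im in valid_images)

# What the reuse handler sends back: a plain list re-populates the gallery.
result = valid_images[0]
gallery_update = [result] if result is not None else []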
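Outside Gradio, the generation step amounts to a single diffusers call on the concatenated canvas, with the user's instruction wrapped in the commit's fixed composition template. A sketch under those assumptions, omitting the DFloat11 weight loading for brevity; the template is abbreviated here, and the app's full wording appears in the diff above.

import torch
from diffusers import FluxKontextPipeline
from PIL import Image

pipe = FluxKontextPipeline.from_pretrained(
    "fuliucansheng/FLUX.1-Kontext-dev-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # keeps VRAM usage down, as in the app

canvas = Image.open("concatenated.png").convert("RGB")  # e.g. the canvas built above
user_prompt = "the cat sits next to the flowers"        # hypothetical edit instruction

# The app embeds the user prompt in a longer fixed template; abbreviated here.
final_prompt = (
    "From the provided reference images, create a unified, cohesive image "
    f"such that {user_prompt}."
)

image = pipe(
    image=canvas,
    prompt=final_prompt,
    guidance_scale=4.0,        # the commit's new default
    width=canvas.size[0],      # generate at the concatenated canvas size
    height=canvas.size[1],
    num_inference_steps=25,    # the commit's new default
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("result.png")

Generating at the canvas's own width and height preserves the side-by-side aspect ratio, which is presumably why the commit leaves the resize-to-1024 logic commented out.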