blanchon committed on
Commit
9e3e526
·
1 Parent(s): 07c4f11
Files changed (3) hide show
  1. README.md +5 -5
  2. app.py +166 -87
  3. uv.lock +0 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: FurnitureAdapter
3
- emoji: 🌖
4
- colorFrom: purple
5
- colorTo: indigo
6
  sdk: gradio
7
  python_version: 3.12
8
  sdk_version: 5.18.0
@@ -18,6 +18,6 @@ pinned: true
18
  license: mit
19
  ---
20
 
21
- # FurnitureBlendingDemoAPI
22
 
23
  ...
 
1
  ---
2
+ title: FurnitureDemo
3
+ emoji: 🪑
4
+ colorFrom: blue
5
+ colorTo: white
6
  sdk: gradio
7
  python_version: 3.12
8
  sdk_version: 5.18.0
 
18
  license: mit
19
  ---
20
 
21
+ # FurnitureDemo
22
 
23
  ...
app.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import zipfile
5
  from io import BytesIO
6
  from pathlib import Path
7
- from typing import Literal, cast
8
 
9
  import gradio as gr
10
  import numpy as np
@@ -12,13 +12,23 @@ import requests
12
  from gradio.components.image_editor import EditorValue
13
  from PIL import Image
14
 
15
- PASSWORD = os.environ.get("PASSWORD", None)
16
- if not PASSWORD:
17
- raise ValueError("PASSWORD is not set")
 
 
18
 
19
- ENDPOINT = os.environ.get("ENDPOINT", None)
20
- if not ENDPOINT:
21
- raise ValueError("ENDPOINT is not set")
 
 
 
 
 
 
 
 
22
 
23
 
24
  def encode_image_as_base64(image: Image.Image) -> str:
@@ -60,99 +70,156 @@ def make_example(image_path: Path, mask_path: Path | None) -> EditorValue:
60
  }
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def predict(
64
  model_type: Literal["schnell", "dev", "pixart"],
65
  image_and_mask: EditorValue,
66
  furniture_reference: Image.Image | None,
67
  prompt: str = "",
68
- subfolder: str = "",
69
  seed: int = 0,
70
  num_inference_steps: int = 28,
71
  max_dimension: int = 512,
72
- margin: int = 64,
73
  crop: bool = True,
74
  num_images_per_prompt: int = 1,
75
  ) -> list[Image.Image] | None:
76
- if not image_and_mask:
77
- gr.Info("Please upload an image and draw a mask")
78
- return None
79
- if not furniture_reference:
80
- gr.Info("Please upload a furniture reference image")
81
  return None
82
 
83
  if model_type == "pixart":
84
  gr.Info("PixArt is not supported yet")
85
  return None
86
 
87
- image_np = image_and_mask["background"]
88
- image_np = cast(np.ndarray, image_np)
89
-
90
- # If the image is empty, return None
91
- if np.sum(image_np) == 0:
92
- gr.Info("Please upload an image")
93
- return None
94
-
95
- alpha_channel = image_and_mask["layers"][0]
96
- alpha_channel = cast(np.ndarray, alpha_channel)
97
- mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
98
 
99
- # if mask_np is empty, return None
100
- if np.sum(mask_np) == 0:
101
  gr.Info("Please mark the areas you want to remove")
102
  return None
103
 
104
- mask_image = Image.fromarray(mask_np).convert("L")
105
- target_image = Image.fromarray(image_np).convert("RGB")
106
-
107
- # Avoid too big image to be sent to the API
108
- mask_image.thumbnail((2048, 2048), Image.Resampling.LANCZOS)
109
- target_image.thumbnail((2048, 2048), Image.Resampling.LANCZOS)
110
- furniture_reference.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
111
-
112
- room_image_input_base64 = encode_image_as_base64(target_image)
113
- room_image_mask_base64 = encode_image_as_base64(mask_image)
114
- furniture_reference_base64 = encode_image_as_base64(furniture_reference)
115
-
116
- room_image_input_base64 = "data:image/png;base64," + room_image_input_base64
117
- room_image_mask_base64 = "data:image/png;base64," + room_image_mask_base64
118
- furniture_reference_base64 = "data:image/png;base64," + furniture_reference_base64
119
-
120
- response = requests.post(
121
- ENDPOINT,
122
- headers={"accept": "application/json", "Content-Type": "application/json"},
123
- json={
124
- "model_type": model_type,
125
- "room_image_input": room_image_input_base64,
126
- "room_image_mask": room_image_mask_base64,
127
- "furniture_reference_image": furniture_reference_base64,
128
- "prompt": prompt,
129
- "subfolder": subfolder,
130
- "seed": seed,
131
- "num_inference_steps": num_inference_steps,
132
- "max_dimension": max_dimension,
133
- "condition_scale": 1.0,
134
- "margin": margin,
135
- "crop": crop,
136
- "num_images_per_prompt": num_images_per_prompt,
137
- "password": PASSWORD,
138
- },
139
  )
140
- if response.status_code != 200:
141
- gr.Info("An error occurred during the generation")
142
- return None
143
 
144
- # Read the returned ZIP file from the response.
145
- zip_bytes = io.BytesIO(response.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- final_image_list: list[Image.Image] = []
 
 
 
 
 
 
 
 
 
 
148
 
149
- # Open the ZIP archive.
150
- with zipfile.ZipFile(zip_bytes, "r") as zip_file:
151
- image_filenames = zip_file.namelist()
152
- for filename in image_filenames:
153
- with zip_file.open(filename) as file:
154
- image = Image.open(file).convert("RGB")
155
- final_image_list.append(image)
 
 
 
 
 
 
156
 
157
  return final_image_list
158
 
@@ -198,7 +265,7 @@ with gr.Blocks(css=css) as demo:
198
  </div>
199
  """)
200
 
201
- with gr.Row() as content:
202
  with gr.Column(elem_id="col-left"):
203
  gr.HTML(
204
  r"""
@@ -219,10 +286,14 @@ with gr.Blocks(css=css) as demo:
219
  sources=["upload"],
220
  show_download_button=False,
221
  interactive=True,
222
- brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
 
 
 
 
223
  transforms=[],
224
  )
225
- image_and_mask_examples = gr.Examples(
226
  examples=[
227
  make_example(path, None)
228
  for path in Path("./examples/scenes").glob("*.png")
@@ -248,7 +319,7 @@ with gr.Blocks(css=css) as demo:
248
  sources=["upload"],
249
  image_mode="RGB",
250
  )
251
- furniture_examples = gr.Examples(
252
  examples=list(Path("./examples/objects").glob("*.png")),
253
  label="Furniture examples",
254
  examples_per_page=6,
@@ -268,7 +339,7 @@ with gr.Blocks(css=css) as demo:
268
  results = gr.Gallery(
269
  label="Result",
270
  format="png",
271
- file_types="image",
272
  show_label=False,
273
  columns=2,
274
  allow_preview=True,
@@ -286,10 +357,6 @@ with gr.Blocks(css=css) as demo:
286
  label="Prompt",
287
  value="",
288
  )
289
- subfolder = gr.Textbox(
290
- label="Subfolder",
291
- value="",
292
- )
293
  seed = gr.Slider(
294
  label="Seed",
295
  minimum=0,
@@ -339,14 +406,23 @@ with gr.Blocks(css=css) as demo:
339
  outputs=num_inference_steps,
340
  )
341
 
 
 
 
 
 
 
 
342
  run_button.click(
 
 
 
343
  fn=predict,
344
  inputs=[
345
  model_type,
346
  image_and_mask,
347
  condition_image,
348
  prompt,
349
- subfolder,
350
  seed,
351
  num_inference_steps,
352
  max_dimension,
@@ -355,7 +431,10 @@ with gr.Blocks(css=css) as demo:
355
  num_images_per_prompt,
356
  ],
357
  outputs=[results],
 
 
 
358
  )
359
 
360
-
361
- demo.launch()
 
4
  import zipfile
5
  from io import BytesIO
6
  from pathlib import Path
7
+ from typing import Literal, TypedDict, cast
8
 
9
  import gradio as gr
10
  import numpy as np
 
12
  from gradio.components.image_editor import EditorValue
13
  from PIL import Image
14
 
15
+ _PASSWORD = os.environ.get("PASSWORD", None)
16
+ if not _PASSWORD:
17
+ msg = "PASSWORD is not set"
18
+ raise ValueError(msg)
19
+ PASSWORD = cast("str", _PASSWORD)
20
 
21
+ _ENDPOINT = os.environ.get("ENDPOINT", None)
22
+ if not _ENDPOINT:
23
+ msg = "ENDPOINT is not set"
24
+ raise ValueError(msg)
25
+ ENDPOINT = cast("str", _ENDPOINT)
26
+
27
+ # Add constants at the top
28
+ THUMBNAIL_MAX_SIZE = 2048
29
+ REFERENCE_MAX_SIZE = 1024
30
+ REQUEST_TIMEOUT = 300 # 5 minutes
31
+ DEFAULT_BRUSH_SIZE = 75
32
 
33
 
34
  def encode_image_as_base64(image: Image.Image) -> str:
 
70
  }
71
 
72
 
73
+ class InputFurnitureBlendingTypedDict(TypedDict):
74
+ return_type: Literal["zipfile", "s3"]
75
+ model_type: Literal["schnell", "dev"]
76
+ room_image_input: str
77
+ bbox: tuple[int, int, int, int]
78
+ furniture_reference_image: str
79
+ prompt: str
80
+ seed: int
81
+ num_inference_steps: int
82
+ max_dimension: int
83
+ margin: int
84
+ crop: bool
85
+ num_images_per_prompt: int
86
+ bucket: str
87
+
88
+
89
+ # Add type hints for the response
90
+ class GenerationResponse(TypedDict):
91
+ images: list[Image.Image]
92
+ error: str | None
93
+
94
+
95
+ def validate_inputs(
96
+ image_and_mask: EditorValue | None,
97
+ furniture_reference: Image.Image | None,
98
+ ) -> tuple[Literal[True], None] | tuple[Literal[False], str]:
99
+ if not image_and_mask:
100
+ return False, "Please upload an image and draw a mask"
101
+
102
+ image_np = cast("np.ndarray", image_and_mask["background"])
103
+ if np.sum(image_np) == 0:
104
+ return False, "Please upload an image"
105
+
106
+ alpha_channel = cast("np.ndarray", image_and_mask["layers"][0])
107
+ mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
108
+ if np.sum(mask_np) == 0:
109
+ return False, "Please mark the areas you want to remove"
110
+
111
+ if not furniture_reference:
112
+ return False, "Please upload a furniture reference image"
113
+
114
+ return True, None
115
+
116
+
117
+ def process_images(
118
+ image_and_mask: EditorValue,
119
+ furniture_reference: Image.Image,
120
+ ) -> tuple[Image.Image, Image.Image, Image.Image]:
121
+ image_np = cast("np.ndarray", image_and_mask["background"])
122
+ alpha_channel = cast("np.ndarray", image_and_mask["layers"][0])
123
+ mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
124
+
125
+ mask_image = Image.fromarray(mask_np).convert("L")
126
+ target_image = Image.fromarray(image_np).convert("RGB")
127
+
128
+ # Resize images
129
+ mask_image.thumbnail(
130
+ (THUMBNAIL_MAX_SIZE, THUMBNAIL_MAX_SIZE), Image.Resampling.LANCZOS
131
+ )
132
+ target_image.thumbnail(
133
+ (THUMBNAIL_MAX_SIZE, THUMBNAIL_MAX_SIZE), Image.Resampling.LANCZOS
134
+ )
135
+ furniture_reference.thumbnail(
136
+ (REFERENCE_MAX_SIZE, REFERENCE_MAX_SIZE), Image.Resampling.LANCZOS
137
+ )
138
+
139
+ return target_image, mask_image, furniture_reference
140
+
141
+
142
  def predict(
143
  model_type: Literal["schnell", "dev", "pixart"],
144
  image_and_mask: EditorValue,
145
  furniture_reference: Image.Image | None,
146
  prompt: str = "",
 
147
  seed: int = 0,
148
  num_inference_steps: int = 28,
149
  max_dimension: int = 512,
150
+ margin: int = 128,
151
  crop: bool = True,
152
  num_images_per_prompt: int = 1,
153
  ) -> list[Image.Image] | None:
154
+ # Validate inputs
155
+ is_valid, error_message = validate_inputs(image_and_mask, furniture_reference)
156
+ if not is_valid and error_message:
157
+ gr.Info(error_message)
 
158
  return None
159
 
160
  if model_type == "pixart":
161
  gr.Info("PixArt is not supported yet")
162
  return None
163
 
164
+ # Process images
165
+ target_image, mask_image, furniture_reference = process_images(
166
+ image_and_mask, cast("Image.Image", furniture_reference)
167
+ )
 
 
 
 
 
 
 
168
 
169
+ bbox = mask_image.getbbox()
170
+ if not bbox:
171
  gr.Info("Please mark the areas you want to remove")
172
  return None
173
 
174
+ # Prepare API request
175
+ room_image_input_base64 = "data:image/png;base64," + encode_image_as_base64(
176
+ target_image
177
+ )
178
+ furniture_reference_base64 = "data:image/png;base64," + encode_image_as_base64(
179
+ furniture_reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  )
 
 
 
181
 
182
+ body = InputFurnitureBlendingTypedDict(
183
+ return_type="zipfile",
184
+ model_type=model_type,
185
+ room_image_input=room_image_input_base64,
186
+ bbox=bbox,
187
+ furniture_reference_image=furniture_reference_base64,
188
+ prompt=prompt,
189
+ seed=seed,
190
+ num_inference_steps=num_inference_steps,
191
+ max_dimension=max_dimension,
192
+ margin=margin,
193
+ crop=crop,
194
+ num_images_per_prompt=num_images_per_prompt,
195
+ bucket="furniture-blending",
196
+ )
197
 
198
+ try:
199
+ response = requests.post(
200
+ ENDPOINT,
201
+ headers={"accept": "application/json", "Content-Type": "application/json"},
202
+ json=body,
203
+ timeout=REQUEST_TIMEOUT,
204
+ )
205
+ response.raise_for_status()
206
+ except requests.RequestException as e:
207
+ gr.Info(f"API request failed: {e!s}")
208
+ return None
209
 
210
+ # Process response
211
+ try:
212
+ zip_bytes = io.BytesIO(response.content)
213
+ final_image_list: list[Image.Image] = []
214
+
215
+ with zipfile.ZipFile(zip_bytes, "r") as zip_file:
216
+ for filename in zip_file.namelist():
217
+ with zip_file.open(filename) as file:
218
+ image = Image.open(file).convert("RGB")
219
+ final_image_list.append(image)
220
+ except (OSError, zipfile.BadZipFile) as e:
221
+ gr.Info(f"Failed to process response: {e!s}")
222
+ return None
223
 
224
  return final_image_list
225
 
 
265
  </div>
266
  """)
267
 
268
+ with gr.Row():
269
  with gr.Column(elem_id="col-left"):
270
  gr.HTML(
271
  r"""
 
286
  sources=["upload"],
287
  show_download_button=False,
288
  interactive=True,
289
+ brush=gr.Brush(
290
+ default_size=DEFAULT_BRUSH_SIZE,
291
+ colors=["#000000"],
292
+ color_mode="fixed",
293
+ ),
294
  transforms=[],
295
  )
296
+ gr.Examples(
297
  examples=[
298
  make_example(path, None)
299
  for path in Path("./examples/scenes").glob("*.png")
 
319
  sources=["upload"],
320
  image_mode="RGB",
321
  )
322
+ gr.Examples(
323
  examples=list(Path("./examples/objects").glob("*.png")),
324
  label="Furniture examples",
325
  examples_per_page=6,
 
339
  results = gr.Gallery(
340
  label="Result",
341
  format="png",
342
+ file_types=["image"],
343
  show_label=False,
344
  columns=2,
345
  allow_preview=True,
 
357
  label="Prompt",
358
  value="",
359
  )
 
 
 
 
360
  seed = gr.Slider(
361
  label="Seed",
362
  minimum=0,
 
406
  outputs=num_inference_steps,
407
  )
408
 
409
+ # Add loading indicator
410
+ with gr.Row():
411
+ loading_indicator = gr.HTML(
412
+ '<div id="loading" style="display:none;">Processing... Please wait.</div>'
413
+ )
414
+
415
+ # Update click handler to show loading state
416
  run_button.click(
417
+ fn=lambda: gr.update(visible=True),
418
+ outputs=[loading_indicator],
419
+ ).then(
420
  fn=predict,
421
  inputs=[
422
  model_type,
423
  image_and_mask,
424
  condition_image,
425
  prompt,
 
426
  seed,
427
  num_inference_steps,
428
  max_dimension,
 
431
  num_images_per_prompt,
432
  ],
433
  outputs=[results],
434
+ ).then(
435
+ fn=lambda: gr.update(visible=False),
436
+ outputs=[loading_indicator],
437
  )
438
 
439
+ if __name__ == "__main__":
440
+ demo.launch()
uv.lock ADDED
The diff for this file is too large to render. See raw diff