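"""Gradio demo for furniture blending.

Upload a room photo, paint an inpainting mask, and provide a furniture
reference image; the inputs are sent to a remote inference endpoint and the
generated images are displayed in a gallery.
"""
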
import base64
import io
import os
import zipfile
from io import BytesIO
from pathlib import Path
from typing import Literal, TypedDict, cast

import gradio as gr
import numpy as np
import requests
from gradio.components.image_editor import EditorValue
from PIL import Image

_PASSWORD = os.environ.get("PASSWORD", None)
if not _PASSWORD:
    msg = "PASSWORD is not set"
    raise ValueError(msg)
PASSWORD = cast("str", _PASSWORD)

_ENDPOINT = os.environ.get("ENDPOINT", None)
if not _ENDPOINT:
    msg = "ENDPOINT is not set"
    raise ValueError(msg)
ENDPOINT = cast("str", _ENDPOINT)

# Constants
THUMBNAIL_MAX_SIZE = 2048
REFERENCE_MAX_SIZE = 1024
REQUEST_TIMEOUT = 300  # 5 minutes
DEFAULT_BRUSH_SIZE = 75


def encode_image_as_base64(image: Image.Image) -> str:
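    """Encode a PIL image as a base64-encoded PNG string."""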
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def make_example(image_path: Path, mask_path: Path | None) -> EditorValue:
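    """Build an EditorValue (background, mask layer, composite) for the room examples."""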
    background_image = Image.open(image_path)
    background_image = background_image.convert("RGB")
    background = np.array(background_image)

    if mask_path:
        mask_image = Image.open(mask_path)
        mask_image = mask_image.convert("RGB")
        mask = np.array(mask_image)
        mask = mask[:, :, 0]
        mask = np.where(mask == 255, 0, 255)  # noqa: PLR2004
    else:
        mask = np.zeros_like(background)
        mask = mask[:, :, 0]

    if background.shape[0] != mask.shape[0] or background.shape[1] != mask.shape[1]:
        msg = "Background and mask must have the same shape"
        raise ValueError(msg)

    layer = np.zeros((background.shape[0], background.shape[1], 4), dtype=np.uint8)
    layer[:, :, 3] = mask

    composite = np.zeros((background.shape[0], background.shape[1], 4), dtype=np.uint8)
    composite[:, :, :3] = background
    composite[:, :, 3] = np.where(mask == 255, 0, 255)  # noqa: PLR2004

    return {
        "background": background,
        "layers": [layer],
        "composite": composite,
    }


class InputFurnitureBlendingTypedDict(TypedDict):
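    """Request body sent to the furniture-blending inference endpoint."""
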
    return_type: Literal["zipfile", "s3"]
    model_type: Literal["schnell", "dev"]
    room_image_input: str
    bbox: tuple[int, int, int, int]
    furniture_reference_image: str
    prompt: str
    seed: int
    num_inference_steps: int
    max_dimension: int
    margin: int
    crop: bool
    num_images_per_prompt: int
    bucket: str


# Type hints for the API response
class GenerationResponse(TypedDict):
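    """Decoded generation result: a list of images plus an optional error message."""
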
    images: list[Image.Image]
    error: str | None


def validate_inputs(
    image_and_mask: EditorValue | None,
    furniture_reference: Image.Image | None,
) -> tuple[Literal[True], None] | tuple[Literal[False], str]:
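    """Check that an image, a non-empty mask, and a furniture reference were provided."""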
    if not image_and_mask:
        return False, "Please upload an image and draw a mask"

    image_np = cast("np.ndarray", image_and_mask["background"])
    if np.sum(image_np) == 0:
        return False, "Please upload an image"

    alpha_channel = cast("np.ndarray", image_and_mask["layers"][0])
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if np.sum(mask_np) == 0:
        return False, "Please mark the areas you want to remove"

    if not furniture_reference:
        return False, "Please upload a furniture reference image"

    return True, None


def process_images(
    image_and_mask: EditorValue,
    furniture_reference: Image.Image,
) -> tuple[Image.Image, Image.Image, Image.Image]:
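    """Extract the target and mask images from the editor value and downscale all inputs."""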
    image_np = cast("np.ndarray", image_and_mask["background"])
    alpha_channel = cast("np.ndarray", image_and_mask["layers"][0])
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)

    mask_image = Image.fromarray(mask_np).convert("L")
    target_image = Image.fromarray(image_np).convert("RGB")

    # Resize images
    mask_image.thumbnail(
        (THUMBNAIL_MAX_SIZE, THUMBNAIL_MAX_SIZE), Image.Resampling.LANCZOS
    )
    target_image.thumbnail(
        (THUMBNAIL_MAX_SIZE, THUMBNAIL_MAX_SIZE), Image.Resampling.LANCZOS
    )
    furniture_reference.thumbnail(
        (REFERENCE_MAX_SIZE, REFERENCE_MAX_SIZE), Image.Resampling.LANCZOS
    )

    return target_image, mask_image, furniture_reference


def predict(
    model_type: Literal["schnell", "dev", "pixart"],
    image_and_mask: EditorValue,
    furniture_reference: Image.Image | None,
    prompt: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    max_dimension: int = 512,
    margin: int = 128,
    crop: bool = True,
    num_images_per_prompt: int = 1,
) -> list[Image.Image] | None:
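    """Send the masked room image and furniture reference to the inference endpoint.

    Returns the generated images, or None (after showing a gr.Info message) when
    the inputs are invalid or the request fails.
    """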
    # Validate inputs
    is_valid, error_message = validate_inputs(image_and_mask, furniture_reference)
    if not is_valid and error_message:
        gr.Info(error_message)
        return None

    if model_type == "pixart":
        gr.Info("PixArt is not supported yet")
        return None

    # Process images
    target_image, mask_image, furniture_reference = process_images(
        image_and_mask, cast("Image.Image", furniture_reference)
    )

    bbox = mask_image.getbbox()
    if not bbox:
        gr.Info("Please mark the areas you want to remove")
        return None

    # Prepare API request
    room_image_input_base64 = "data:image/png;base64," + encode_image_as_base64(
        target_image
    )
    furniture_reference_base64 = "data:image/png;base64," + encode_image_as_base64(
        furniture_reference
    )
    body = InputFurnitureBlendingTypedDict(
        return_type="zipfile",
        model_type=model_type,
        room_image_input=room_image_input_base64,
        bbox=bbox,
        furniture_reference_image=furniture_reference_base64,
        prompt=prompt,
        seed=seed,
        num_inference_steps=num_inference_steps,
        max_dimension=max_dimension,
        margin=margin,
        crop=crop,
        num_images_per_prompt=num_images_per_prompt,
        bucket="furniture-blending",
    )

    try:
        response = requests.post(
            ENDPOINT,
            headers={"accept": "application/json", "Content-Type": "application/json"},
            json=body,
            timeout=REQUEST_TIMEOUT,
        )
        response.raise_for_status()
    except requests.RequestException as e:
        gr.Info(f"API request failed: {e!s}")
        return None

    # Process response
    try:
        zip_bytes = io.BytesIO(response.content)
        final_image_list: list[Image.Image] = []
        with zipfile.ZipFile(zip_bytes, "r") as zip_file:
            for filename in zip_file.namelist():
                with zip_file.open(filename) as file:
                    image = Image.open(file).convert("RGB")
                    final_image_list.append(image)
    except (OSError, zipfile.BadZipFile) as e:
        gr.Info(f"Failed to process response: {e!s}")
        return None

    return final_image_list


css = r"""
#col-left {
    margin: 0 auto;
    max-width: 430px;
}
#col-mid {
    margin: 0 auto;
    max-width: 430px;
}
#col-right {
    margin: 0 auto;
    max-width: 430px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML("""
    <div style="display: flex; justify-content: center; text-align:center; flex-direction: column;">
        <h1 style="color: #333;">🪑 Furniture Blending Demo</h1>
        <div style="max-width: 800px; margin: 0 auto;">
            <p style="font-size: 16px;">Upload an image, draw a mask on the areas you want to remove, and upload a furniture reference image.</p>
            <p style="font-size: 16px;">
                For the best results, draw square masks.
                Flux dev gives better results than schnell but is slower.
                The object reference should be a single object on a white background.
            </p>
            <p style="font-size: 16px;">
                You can edit the object with the prompt.
                For example, add "red couch" to the prompt to make the couch red.
            </p>
            <br>
            <p style="font-size: 16px;">⚠️ Note that images are compressed to reduce the workload on the demo.</p>
        </div>
    </div>
    """)
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                r"""
                <div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px">
                    <div>
                        🪟 Room image with inpainting mask ⬇️
                    </div>
                </div>
                """,
                max_height=50,
            )
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                height="full",
                width="full",
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                brush=gr.Brush(
                    default_size=DEFAULT_BRUSH_SIZE,
                    colors=["#000000"],
                    color_mode="fixed",
                ),
                transforms=[],
            )
            gr.Examples(
                examples=[
                    make_example(path, None)
                    for path in Path("./examples/scenes").glob("*.png")
                ],
                label="Room examples",
                examples_per_page=6,
                inputs=[image_and_mask],
            )
        with gr.Column(elem_id="col-mid"):
            gr.HTML(
                r"""
                <div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px">
                    <div>
                        🪑 Furniture reference image ⬇️
                    </div>
                </div>
                """,
                max_height=50,
            )
            condition_image = gr.Image(
                label="Furniture Reference",
                type="pil",
                sources=["upload"],
                image_mode="RGB",
            )
            gr.Examples(
                examples=list(Path("./examples/objects").glob("*.png")),
                label="Furniture examples",
                examples_per_page=6,
                inputs=[condition_image],
            )
        with gr.Column(elem_id="col-right"):
            gr.HTML(
                r"""
                <div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px">
                    <div>
                        🔥 Press Run ⬇️
                    </div>
                </div>
                """,
                max_height=50,
            )
            results = gr.Gallery(
                label="Result",
                format="png",
                file_types=["image"],
                show_label=False,
                columns=2,
                allow_preview=True,
                preview=True,
            )
            model_type = gr.Radio(
                choices=["schnell", "dev", "pixart"],
                value="dev",
                label="Model Type",
            )
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt",
                    value="",
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=np.iinfo(np.int32).max,
                    step=1,
                    value=0,
                )
                num_images_per_prompt = gr.Slider(
                    label="Number of images per prompt",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=2,
                )
                crop = gr.Checkbox(
                    label="Crop",
                    value=False,
                )
                margin = gr.Slider(
                    label="Margin",
                    minimum=0,
                    maximum=256,
                    step=16,
                    value=128,
                )
                with gr.Column():
                    max_dimension = gr.Slider(
                        label="Max Dimension",
                        minimum=256,
                        maximum=1024,
                        step=128,
                        value=512,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=4,
                        maximum=30,
                        step=2,
                        value=28,
                    )

    # Change the number of inference steps based on the model type
    model_type.change(
        fn=lambda x: gr.update(value=4 if x == "schnell" else 28),
        inputs=model_type,
        outputs=num_inference_steps,
    )

    # Loading indicator, hidden until a run is in progress
    with gr.Row():
        loading_indicator = gr.HTML(
            '<div id="loading">Processing... Please wait.</div>',
            visible=False,
        )

    # Show the loading indicator, run the prediction, then hide the indicator
    run_button.click(
        fn=lambda: gr.update(visible=True),
        outputs=[loading_indicator],
    ).then(
        fn=predict,
        inputs=[
            model_type,
            image_and_mask,
            condition_image,
            prompt,
            seed,
            num_inference_steps,
            max_dimension,
            margin,
            crop,
            num_images_per_prompt,
        ],
        outputs=[results],
    ).then(
        fn=lambda: gr.update(visible=False),
        outputs=[loading_indicator],
    )
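
# Launch the demo when run as a script. A typical invocation (assuming this
# file is saved as app.py; the endpoint URL and password values below are
# placeholders):
#   ENDPOINT=https://<inference-host> PASSWORD=<secret> python app.py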
if __name__ == "__main__":
    demo.launch()