"""Gradio demo that blends a furniture reference image into a masked room photo.

The generation itself runs on a remote inference service (``ENDPOINT``). This
app only collects inputs, encodes them as PNG data URIs, posts them as JSON,
and unpacks the ZIP archive of generated images that the service returns.

Required environment variables:
    PASSWORD: shared secret forwarded to the inference service in the payload.
    ENDPOINT: URL of the inference service.

Optional environment variables:
    REQUEST_TIMEOUT: read timeout in seconds for the inference request
        (default 1800).
"""

import base64
import io
import os
import zipfile
from io import BytesIO
from pathlib import Path
from typing import Literal, cast

import gradio as gr
import numpy as np
import requests
from gradio.components.image_editor import EditorValue
from PIL import Image

PASSWORD = os.environ.get("PASSWORD", None)
if not PASSWORD:
    raise ValueError("PASSWORD is not set")

ENDPOINT = os.environ.get("ENDPOINT", None)
if not ENDPOINT:
    raise ValueError("ENDPOINT is not set")

# (connect, read) timeouts for the inference request. Without a timeout a
# stalled backend would block the Gradio worker forever. The read timeout is
# generous because multi-image generations can take a long time.
REQUEST_TIMEOUT = (30.0, float(os.environ.get("REQUEST_TIMEOUT", "1800")))

# Bounding boxes used to shrink images before upload (keeps payloads small).
MAX_ROOM_SIZE = (2048, 2048)
MAX_REFERENCE_SIZE = (1024, 1024)


def encode_image_as_base64(image: Image.Image) -> str:
    """Return *image* serialised as PNG and base64-encoded (no data-URI prefix)."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _to_png_data_uri(image: Image.Image) -> str:
    """Return *image* as a ``data:image/png;base64,...`` URI for the JSON payload."""
    return "data:image/png;base64," + encode_image_as_base64(image)


def make_example(image_path: Path, mask_path: Path | None) -> EditorValue:
    """Build an ``ImageMask`` value from a room image and an optional mask file.

    Args:
        image_path: Path to the room (background) image.
        mask_path: Optional path to a mask image. Pixels whose first channel is
            exactly 255 (white) are left unpainted; every other pixel becomes
            part of the painted layer. ``None`` yields an empty layer.

    Returns:
        Dict with ``background`` (RGB array), ``layers`` (a single RGBA layer
        whose alpha marks the painted area) and ``composite`` (RGBA copy of the
        background with painted pixels made transparent).

    Raises:
        ValueError: If the mask and the background differ in height or width.
    """
    # Context managers close the underlying files; convert() already loads the data.
    with Image.open(image_path) as background_image:
        background = np.array(background_image.convert("RGB"))
    height, width = background.shape[:2]

    if mask_path:
        with Image.open(mask_path) as mask_image:
            mask_channel = np.array(mask_image.convert("RGB"))[:, :, 0]
        # White means "keep"; everything else is painted with full alpha.
        mask = np.where(mask_channel == 255, 0, 255)  # noqa: PLR2004
    else:
        mask = np.zeros((height, width), dtype=background.dtype)

    if mask.shape[0] != height or mask.shape[1] != width:
        msg = "Background and mask must have the same shape"
        raise ValueError(msg)

    layer = np.zeros((height, width, 4), dtype=np.uint8)
    layer[:, :, 3] = mask

    composite = np.zeros((height, width, 4), dtype=np.uint8)
    composite[:, :, :3] = background
    # Painted pixels are shown as erased (transparent) in the composite.
    composite[:, :, 3] = np.where(mask == 255, 0, 255)  # noqa: PLR2004

    return {
        "background": background,
        "layers": [layer],
        "composite": composite,
    }


def _extract_room_and_mask(
    image_and_mask: EditorValue,
) -> tuple[Image.Image, Image.Image] | None:
    """Return (RGB room image, L-mode mask) from the editor value, or None after notifying the user."""
    image_np = image_and_mask.get("background")
    # An untouched editor has no background; an all-zero one counts as empty too.
    if image_np is None or not np.asarray(image_np).any():
        gr.Info("Please upload an image")
        return None

    layers = image_and_mask.get("layers") or []
    if not layers or layers[0] is None:
        gr.Info("Please mark the areas you want to remove")
        return None

    alpha_channel = cast(np.ndarray, layers[0])
    # Any pixel with non-zero alpha was painted by the user and must be replaced.
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)
    if not mask_np.any():
        gr.Info("Please mark the areas you want to remove")
        return None

    mask_image = Image.fromarray(mask_np).convert("L")
    target_image = Image.fromarray(cast(np.ndarray, image_np)).convert("RGB")
    return target_image, mask_image


def _decode_zip_images(payload: bytes) -> list[Image.Image]:
    """Decode every file entry of a ZIP archive as an RGB image, in archive order."""
    images: list[Image.Image] = []
    with zipfile.ZipFile(io.BytesIO(payload), "r") as zip_file:
        for info in zip_file.infolist():
            if info.is_dir():  # directory entries carry no image data
                continue
            with zip_file.open(info) as file:
                # convert() forces a full load, so the image outlives the file handle.
                images.append(Image.open(file).convert("RGB"))
    return images


def predict(
    model_type: Literal["schnell", "dev", "pixart"],
    image_and_mask: EditorValue,
    furniture_reference: Image.Image | None,
    prompt: str = "",
    subfolder: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    max_dimension: int = 512,
    margin: int = 64,
    crop: bool = True,
    num_images_per_prompt: int = 1,
) -> list[Image.Image] | None:
    """Send the room, mask and furniture reference to the inference service.

    Args:
        model_type: Backend model to use; ``"pixart"`` is rejected for now.
        image_and_mask: Value of the ``ImageMask`` editor (room + painted layer).
        furniture_reference: Image of the furniture to insert.
        prompt, subfolder, seed, num_inference_steps, max_dimension, margin,
        crop, num_images_per_prompt: Forwarded unchanged in the JSON payload.

    Returns:
        The generated images in the order they appear in the returned ZIP, or
        ``None`` if validation, the request, or decoding failed. Every failure
        path notifies the user via ``gr.Info``.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None

    if furniture_reference is None:
        gr.Info("Please upload a furniture reference image")
        return None

    if model_type == "pixart":
        gr.Info("PixArt is not supported yet")
        return None

    extracted = _extract_room_and_mask(image_and_mask)
    if extracted is None:
        return None
    target_image, mask_image = extracted

    # Avoid too big image to be sent to the API
    mask_image.thumbnail(MAX_ROOM_SIZE, Image.Resampling.LANCZOS)
    target_image.thumbnail(MAX_ROOM_SIZE, Image.Resampling.LANCZOS)
    # convert() returns a copy, so the caller's image is not resized in place.
    reference = furniture_reference.convert("RGB")
    reference.thumbnail(MAX_REFERENCE_SIZE, Image.Resampling.LANCZOS)

    payload = {
        "model_type": model_type,
        "room_image_input": _to_png_data_uri(target_image),
        "room_image_mask": _to_png_data_uri(mask_image),
        "furniture_reference_image": _to_png_data_uri(reference),
        "prompt": prompt,
        "subfolder": subfolder,
        "seed": seed,
        "num_inference_steps": num_inference_steps,
        "max_dimension": max_dimension,
        "condition_scale": 1.0,
        "margin": margin,
        "crop": crop,
        "num_images_per_prompt": num_images_per_prompt,
        "password": PASSWORD,
    }

    try:
        response = requests.post(
            ENDPOINT,
            headers={"accept": "application/json", "Content-Type": "application/json"},
            json=payload,
            timeout=REQUEST_TIMEOUT,
        )
    except requests.RequestException:
        # Connection failures and timeouts are reported like a server error.
        gr.Info("An error occurred during the generation")
        return None

    if response.status_code != 200:  # noqa: PLR2004
        gr.Info("An error occurred during the generation")
        return None

    try:
        return _decode_zip_images(response.content)
    except (zipfile.BadZipFile, OSError):
        # Not a ZIP, or an entry that PIL cannot read (UnidentifiedImageError is an OSError).
        gr.Info("An error occurred during the generation")
        return None


css = r"""
#col-left {
    margin: 0 auto;
    max-width: 430px;
}
#col-mid {
    margin: 0 auto;
    max-width: 430px;
}
#col-right {
    margin: 0 auto;
    max-width: 430px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
"""


with gr.Blocks(css=css) as demo:
    gr.HTML("""

🪑 Furniture Blending Demo

Upload an image, draw a mask on the areas you want to remove, and upload a furniture reference image.

For the best results, make square masks. Flux dev give better results than the schnell but is slower. Object reference should be a single object with white background.

You can edit the object with the prompt. For example, you can add "red couch" to the prompt to make the couch red.


⚠️ Note that the images are compressed to reduce the workloads of the demo.

""")
    with gr.Row() as content:
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                r"""
🪟 Room image with inpainting mask ⬇️
""",
                max_height=50,
            )
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                height="full",
                width="full",
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
            # sorted() gives a stable example order; Path.glob order is filesystem-dependent.
            image_and_mask_examples = gr.Examples(
                examples=[
                    make_example(path, None)
                    for path in sorted(Path("./examples/scenes").glob("*.png"))
                ],
                label="Room examples",
                examples_per_page=6,
                inputs=[image_and_mask],
            )
        with gr.Column(elem_id="col-mid"):
            gr.HTML(
                r"""
🪑 Furniture reference image ⬇️
""",
                max_height=50,
            )
            condition_image = gr.Image(
                label="Furniture Reference",
                type="pil",
                sources=["upload"],
                image_mode="RGB",
            )
            furniture_examples = gr.Examples(
                examples=sorted(Path("./examples/objects").glob("*.png")),
                label="Furniture examples",
                examples_per_page=6,
                inputs=[condition_image],
            )
        with gr.Column(elem_id="col-right"):
            gr.HTML(
                r"""
🔥 Press Run ⬇️
""",
                max_height=50,
            )
            results = gr.Gallery(
                label="Result",
                format="png",
                file_types="image",
                show_label=False,
                columns=2,
                allow_preview=True,
                preview=True,
            )
            model_type = gr.Radio(
                choices=["schnell", "dev", "pixart"],
                value="schnell",
                label="Model Type",
            )
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt",
                    value="",
                )
                subfolder = gr.Textbox(
                    label="Subfolder",
                    value="",
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=np.iinfo(np.int32).max,
                    step=1,
                    value=0,
                )
                num_images_per_prompt = gr.Slider(
                    label="Number of images per prompt",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=4,
                )
                crop = gr.Checkbox(
                    label="Crop",
                    value=False,
                )
                margin = gr.Slider(
                    label="Margin",
                    minimum=0,
                    maximum=256,
                    step=16,
                    value=128,
                )
                with gr.Column():
                    max_dimension = gr.Slider(
                        label="Max Dimension",
                        minimum=256,
                        maximum=1024,
                        step=128,
                        value=512,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=4,
                        maximum=30,
                        step=2,
                        value=4,
                    )

    # Change the number of inference steps based on the model type
    model_type.change(
        fn=lambda x: gr.update(value=4 if x == "schnell" else 28),
        inputs=model_type,
        outputs=num_inference_steps,
    )

    run_button.click(
        fn=predict,
        inputs=[
            model_type,
            image_and_mask,
            condition_image,
            prompt,
            subfolder,
            seed,
            num_inference_steps,
            max_dimension,
            margin,
            crop,
            num_images_per_prompt,
        ],
        outputs=[results],
    )


# Launch only when run as a script (`python app.py`), not when imported.
if __name__ == "__main__":
    demo.launch()