# Author: blanchon · Last commit: "Update" (30bd730) · Size: 12 kB
import base64
import io
import os
import zipfile
from io import BytesIO
from pathlib import Path
from typing import Literal, cast
import gradio as gr
import numpy as np
import requests
from gradio.components.image_editor import EditorValue
from PIL import Image
def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        ValueError: If the variable is unset or set to an empty string.
    """
    value = os.environ.get(name)
    if not value:
        msg = f"{name} is not set"
        raise ValueError(msg)
    return value


# Both settings are mandatory; fail fast at import time (PASSWORD is checked first).
PASSWORD = _require_env("PASSWORD")
ENDPOINT = _require_env("ENDPOINT")
def encode_image_as_base64(image: Image.Image) -> str:
    """Serialize *image* as PNG and return those bytes as a base64 text string."""
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
def make_example(image_path: Path, mask_path: Path | None) -> EditorValue:
    """Build an ``EditorValue`` (background / layers / composite) for ``gr.ImageMask`` examples.

    Args:
        image_path: Path to the room image; it is converted to RGB.
        mask_path: Optional path to a mask image. Pixels whose red channel is
            exactly 255 (e.g. white) are left unmasked; every other pixel is
            marked as masked. When falsy, the example starts with no mask.

    Returns:
        A dict with:
        - ``"background"``: the RGB image as an array of shape ``(H, W, 3)``;
        - ``"layers"``: a single ``(H, W, 4)`` uint8 layer, black, whose alpha
          is 255 on masked pixels and 0 elsewhere;
        - ``"composite"``: the background as ``(H, W, 4)`` uint8 RGBA, fully
          transparent on masked pixels and opaque elsewhere.

    Raises:
        ValueError: If the mask and the image differ in height or width.
    """
    # Context managers close the underlying files deterministically instead of
    # relying on Pillow's lazy loading / garbage collection.
    with Image.open(image_path) as background_image:
        background = np.array(background_image.convert("RGB"))
    height, width = background.shape[:2]

    if mask_path:
        with Image.open(mask_path) as mask_image:
            mask_red = np.array(mask_image.convert("RGB"))[:, :, 0]
        # A red value of 255 means "keep"; anything else becomes masked (255).
        mask = np.where(mask_red == 255, 0, 255).astype(np.uint8)  # noqa: PLR2004
    else:
        mask = np.zeros((height, width), dtype=np.uint8)

    if mask.shape[:2] != (height, width):
        msg = "Background and mask must have the same shape"
        raise ValueError(msg)

    # Drawing layer: only the alpha channel carries the mask.
    layer = np.zeros((height, width, 4), dtype=np.uint8)
    layer[:, :, 3] = mask

    # Composite: background with masked pixels punched out (alpha 0).
    composite = np.zeros((height, width, 4), dtype=np.uint8)
    composite[:, :, :3] = background
    composite[:, :, 3] = np.where(mask == 255, 0, 255)  # noqa: PLR2004

    return {
        "background": background,
        "layers": [layer],
        "composite": composite,
    }
# (connect, read) timeouts in seconds for the inference request. The read
# timeout bounds how long we wait for the backend to start answering; 600 s is
# a generous guess — tune it to the backend's real worst-case latency.
_REQUEST_TIMEOUT_SECONDS: tuple[float, float] = (10, 600)


def _to_png_data_uri(image: Image.Image) -> str:
    """Return *image* encoded as a ``data:image/png;base64,...`` URI."""
    return "data:image/png;base64," + encode_image_as_base64(image)


def _read_images_from_zip(payload: bytes) -> list[Image.Image]:
    """Decode every non-directory entry of the ZIP archive *payload* as an RGB image.

    Raises:
        zipfile.BadZipFile: If *payload* is not a ZIP archive.
        OSError: If an entry cannot be decoded as an image
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    images: list[Image.Image] = []
    with zipfile.ZipFile(io.BytesIO(payload), "r") as archive:
        for entry in archive.infolist():
            if entry.is_dir():
                continue
            with archive.open(entry) as file:
                # convert() forces a full decode while the entry is still open.
                images.append(Image.open(file).convert("RGB"))
    return images


def predict(
    model_type: Literal["schnell", "dev", "pixart"],
    image_and_mask: EditorValue,
    furniture_reference: Image.Image | None,
    prompt: str = "",
    subfolder: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    max_dimension: int = 512,
    margin: int = 64,
    crop: bool = True,
    num_images_per_prompt: int = 1,
) -> list[Image.Image] | None:
    """Send a furniture-blending request to ``ENDPOINT`` and return the generated images.

    Validates the UI inputs, downsizes the room image, the mask and the
    furniture reference, posts them as PNG data URIs together with the
    generation settings, and decodes the ZIP archive returned by the backend.
    Every failure shows a ``gr.Info`` message and returns ``None``.

    Args:
        model_type: Backend model to use; ``"pixart"`` is rejected here.
        image_and_mask: ``gr.ImageMask`` value. ``"background"`` is the room
            image; non-transparent pixels of ``layers[0]`` mark the area to inpaint.
        furniture_reference: Image of the object to insert. It is not modified.
        prompt, subfolder, seed, num_inference_steps, max_dimension, margin,
        crop, num_images_per_prompt: Forwarded unchanged to the backend.

    Returns:
        The decoded result images, or ``None`` when validation, the request
        or the response decoding fails.
    """
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None
    if furniture_reference is None:
        gr.Info("Please upload a furniture reference image")
        return None
    if model_type == "pixart":
        gr.Info("PixArt is not supported yet")
        return None

    image_np = image_and_mask.get("background")
    # A missing or all-zero background means no image was uploaded.
    if image_np is None or not np.any(image_np):
        gr.Info("Please upload an image")
        return None
    image_np = cast(np.ndarray, image_np)

    # Any non-transparent pixel of the first drawing layer is part of the mask.
    layers = image_and_mask.get("layers") or []
    mask_np = None
    if layers:
        alpha_channel = cast(np.ndarray, layers[0])[:, :, 3]
        mask_np = np.where(alpha_channel == 0, 0, 255).astype(np.uint8)
    if mask_np is None or not mask_np.any():
        gr.Info("Please mark the areas you want to remove")
        return None

    mask_image = Image.fromarray(mask_np).convert("L")
    target_image = Image.fromarray(image_np).convert("RGB")
    # thumbnail() resizes in place: work on a copy so the caller's image is untouched.
    reference_image = furniture_reference.copy()

    # Bound the payload size sent to the API. NOTE(review): the drawing layer is
    # presumably the same size as the background, so mask and room image end up
    # with identical dimensions — confirm against the ImageMask contract.
    mask_image.thumbnail((2048, 2048), Image.Resampling.LANCZOS)
    target_image.thumbnail((2048, 2048), Image.Resampling.LANCZOS)
    reference_image.thumbnail((1024, 1024), Image.Resampling.LANCZOS)

    payload = {
        "model_type": model_type,
        "room_image_input": _to_png_data_uri(target_image),
        "room_image_mask": _to_png_data_uri(mask_image),
        "furniture_reference_image": _to_png_data_uri(reference_image),
        "prompt": prompt,
        "subfolder": subfolder,
        "seed": seed,
        "num_inference_steps": num_inference_steps,
        "max_dimension": max_dimension,
        "condition_scale": 1.0,
        "margin": margin,
        "crop": crop,
        "num_images_per_prompt": num_images_per_prompt,
        "password": PASSWORD,
    }

    try:
        response = requests.post(
            ENDPOINT,
            headers={"accept": "application/json", "Content-Type": "application/json"},
            json=payload,
            timeout=_REQUEST_TIMEOUT_SECONDS,
        )
    except requests.RequestException:
        # Connection failures and timeouts are reported like HTTP errors.
        gr.Info("An error occurred during the generation")
        return None
    if response.status_code != 200:
        gr.Info("An error occurred during the generation")
        return None

    # The backend answers with a ZIP archive containing one file per image.
    try:
        return _read_images_from_zip(response.content)
    except (zipfile.BadZipFile, OSError):
        gr.Info("An error occurred during the generation")
        return None
# Page stylesheet passed to gr.Blocks(css=css): centers the columns whose
# elem_id is col-left / col-mid / col-right and caps them at 430px wide.
# NOTE(review): no element in this file uses elem_id="col-showcase", so that
# rule currently has no effect.
css = r"""
#col-left {
margin: 0 auto;
max-width: 430px;
}
#col-mid {
margin: 0 auto;
max-width: 430px;
}
#col-right {
margin: 0 auto;
max-width: 430px;
}
#col-showcase {
margin: 0 auto;
max-width: 1100px;
}
"""
# UI: three columns (room image + mask, furniture reference, results/settings)
# wired to `predict`. The app is launched at import time by `demo.launch()`.
with gr.Blocks(css=css) as demo:
    # Page header with usage instructions.
    gr.HTML("""
<div style="display: flex; justify-content: center; text-align:center; flex-direction: column;">
<h1 style="color: #333;">🪑 Furniture Blending Demo</h1>
<div style="max-width: 800px; margin: 0 auto;">
<p style="font-size: 16px;">Upload an image, draw a mask on the areas you want to remove, and upload a furniture reference image.</p>
<p style="font-size: 16px;">
For the best results, make square masks.
Flux dev gives better results than Flux schnell but is slower.
The object reference should be a single object on a white background.
</p>
<p style="font-size: 16px;">
You can edit the object with the prompt.
For example, you can add "red couch" to the prompt to make the couch red.
</p>
<br>
<p style="font-size: 16px;">⚠️ Note that the images are compressed to reduce the demo's workload. </p>
</div>
</div>
""")
    with gr.Row() as content:
        # Left column: room image on which the user paints the inpainting mask.
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                r"""
<div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px">
<div>
🪟 Room image with inpainting mask ⬇️
</div>
</div>
""",
                max_height=50,
            )
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                height="full",
                width="full",
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
            # sorted(): Path.glob order is filesystem-dependent; keep examples stable.
            image_and_mask_examples = gr.Examples(
                examples=[
                    make_example(path, None)
                    for path in sorted(Path("./examples/scenes").glob("*.png"))
                ],
                label="Room examples",
                examples_per_page=6,
                inputs=[image_and_mask],
            )
        # Middle column: reference image of the furniture to insert.
        with gr.Column(elem_id="col-mid"):
            gr.HTML(
                r"""
<div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px">
<div>
🪑 Furniture reference image ⬇️
</div>
</div>
""",
                max_height=50,
            )
            condition_image = gr.Image(
                label="Furniture Reference",
                type="pil",
                sources=["upload"],
                image_mode="RGB",
            )
            furniture_examples = gr.Examples(
                examples=sorted(Path("./examples/objects").glob("*.png")),
                label="Furniture examples",
                examples_per_page=6,
                inputs=[condition_image],
            )
        # Right column: results gallery, model choice, run button and settings.
        with gr.Column(elem_id="col-right"):
            gr.HTML(
                r"""
<div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px">
<div>
🔥 Press Run ⬇️
</div>
</div>
""",
                max_height=50,
            )
            results = gr.Gallery(
                label="Result",
                format="png",
                # file_types expects a list of types/extensions, not a bare string.
                file_types=["image"],
                show_label=False,
                columns=2,
                allow_preview=True,
                preview=True,
            )
            model_type = gr.Radio(
                choices=["schnell", "dev", "pixart"],
                value="schnell",
                label="Model Type",
            )
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt",
                    value="",
                )
                subfolder = gr.Textbox(
                    label="Subfolder",
                    value="",
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=np.iinfo(np.int32).max,
                    step=1,
                    value=0,
                )
                num_images_per_prompt = gr.Slider(
                    label="Number of images per prompt",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=4,
                )
                crop = gr.Checkbox(
                    label="Crop",
                    value=False,
                )
                margin = gr.Slider(
                    label="Margin",
                    minimum=0,
                    maximum=256,
                    step=16,
                    value=128,
                )
                with gr.Column():
                    max_dimension = gr.Slider(
                        label="Max Dimension",
                        minimum=256,
                        maximum=1024,
                        step=128,
                        value=512,
                    )
                    # Default of 4 matches the default "schnell" model.
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=4,
                        maximum=30,
                        step=2,
                        value=4,
                    )
    # Change the number of inference steps based on the model type:
    # 4 for "schnell", 28 for any other choice.
    model_type.change(
        fn=lambda x: gr.update(value=4 if x == "schnell" else 28),
        inputs=model_type,
        outputs=num_inference_steps,
    )
    # Input order must match predict()'s positional parameters.
    run_button.click(
        fn=predict,
        inputs=[
            model_type,
            image_and_mask,
            condition_image,
            prompt,
            subfolder,
            seed,
            num_inference_steps,
            max_dimension,
            margin,
            crop,
            num_images_per_prompt,
        ],
        outputs=[results],
    )
demo.launch()