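"""Gradio demo for blending a furniture reference into a room image.

The app sends the room image, an inpainting mask drawn by the user, and a
furniture reference image to a remote inference endpoint, then displays the
generated results. The PASSWORD and ENDPOINT environment variables must be set.
"""
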
import base64
import io
import os
import zipfile
from io import BytesIO
from pathlib import Path
from typing import Literal, cast

import gradio as gr
import numpy as np
import requests
from gradio.components.image_editor import EditorValue
from PIL import Image

PASSWORD = os.environ.get("PASSWORD", None)
if not PASSWORD:
    raise ValueError("PASSWORD is not set")

ENDPOINT = os.environ.get("ENDPOINT", None)
if not ENDPOINT:
    raise ValueError("ENDPOINT is not set")


def encode_image_as_base64(image: Image.Image) -> str:
    """Serialize a PIL image to a base64-encoded PNG string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def make_example(image_path: Path, mask_path: Path | None) -> EditorValue:
    """Build an EditorValue (background, layers, composite) for gr.ImageMask examples."""
    background_image = Image.open(image_path)
    background_image = background_image.convert("RGB")
    background = np.array(background_image)

    if mask_path:
        mask_image = Image.open(mask_path)
        mask_image = mask_image.convert("RGB")
        mask = np.array(mask_image)
        mask = mask[:, :, 0]
        # Invert the mask: white (255) pixels become transparent (0).
        mask = np.where(mask == 255, 0, 255)  # noqa: PLR2004
    else:
        mask = np.zeros_like(background)
        mask = mask[:, :, 0]

    if background.shape[0] != mask.shape[0] or background.shape[1] != mask.shape[1]:
        msg = "Background and mask must have the same shape"
        raise ValueError(msg)

    layer = np.zeros((background.shape[0], background.shape[1], 4), dtype=np.uint8)
    layer[:, :, 3] = mask

    composite = np.zeros((background.shape[0], background.shape[1], 4), dtype=np.uint8)
    composite[:, :, :3] = background
    composite[:, :, 3] = np.where(mask == 255, 0, 255)  # noqa: PLR2004

    return {
        "background": background,
        "layers": [layer],
        "composite": composite,
    }


def predict(
    model_type: Literal["schnell", "dev", "pixart"],
    image_and_mask: EditorValue,
    furniture_reference: Image.Image | None,
    prompt: str = "",
    subfolder: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    max_dimension: int = 512,
    margin: int = 64,
    crop: bool = True,
    num_images_per_prompt: int = 1,
) -> list[Image.Image] | None:
    """Send the room image, mask, and furniture reference to the inference endpoint."""
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None
    if not furniture_reference:
        gr.Info("Please upload a furniture reference image")
        return None
    if model_type == "pixart":
        gr.Info("PixArt is not supported yet")
        return None
    image_np = image_and_mask["background"]
    image_np = cast(np.ndarray, image_np)

    # If the image is empty, return None
    if np.sum(image_np) == 0:
        gr.Info("Please upload an image")
        return None

    # The drawn mask lives in the alpha channel of the first editor layer.
    alpha_channel = image_and_mask["layers"][0]
    alpha_channel = cast(np.ndarray, alpha_channel)
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)

    # If the mask is empty, return None
    if np.sum(mask_np) == 0:
        gr.Info("Please mark the areas you want to remove")
        return None

    mask_image = Image.fromarray(mask_np).convert("L")
    target_image = Image.fromarray(image_np).convert("RGB")

    # Downscale large images before sending them to the API
    mask_image.thumbnail((2048, 2048), Image.Resampling.LANCZOS)
    target_image.thumbnail((2048, 2048), Image.Resampling.LANCZOS)
    furniture_reference.thumbnail((1024, 1024), Image.Resampling.LANCZOS)

    room_image_input_base64 = encode_image_as_base64(target_image)
    room_image_mask_base64 = encode_image_as_base64(mask_image)
    furniture_reference_base64 = encode_image_as_base64(furniture_reference)

    room_image_input_base64 = "data:image/png;base64," + room_image_input_base64
    room_image_mask_base64 = "data:image/png;base64," + room_image_mask_base64
    furniture_reference_base64 = "data:image/png;base64," + furniture_reference_base64
    response = requests.post(
        ENDPOINT,
        headers={"accept": "application/json", "Content-Type": "application/json"},
        json={
            "model_type": model_type,
            "room_image_input": room_image_input_base64,
            "room_image_mask": room_image_mask_base64,
            "furniture_reference_image": furniture_reference_base64,
            "prompt": prompt,
            "subfolder": subfolder,
            "seed": seed,
            "num_inference_steps": num_inference_steps,
            "max_dimension": max_dimension,
            "condition_scale": 1.0,
            "margin": margin,
            "crop": crop,
            "num_images_per_prompt": num_images_per_prompt,
            "password": PASSWORD,
        },
        timeout=600,  # generation can be slow; the exact limit is an assumption
    )
    if response.status_code != 200:
        gr.Info("An error occurred during the generation")
        return None
    # Read the returned ZIP file from the response.
    zip_bytes = io.BytesIO(response.content)
    final_image_list: list[Image.Image] = []

    # Open the ZIP archive and decode every image it contains.
    with zipfile.ZipFile(zip_bytes, "r") as zip_file:
        image_filenames = zip_file.namelist()
        for filename in image_filenames:
            with zip_file.open(filename) as file:
                # .convert() forces PIL to load the data before the file closes.
                image = Image.open(file).convert("RGB")
                final_image_list.append(image)

    return final_image_list
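
# Example (untested sketch): calling predict() directly for local testing,
# bypassing the UI. The file paths below are hypothetical placeholders.
#
#   example = make_example(
#       Path("./examples/scenes/room.png"),       # hypothetical scene image
#       Path("./examples/scenes/room_mask.png"),  # hypothetical mask image
#   )
#   reference = Image.open("./examples/objects/couch.png")  # hypothetical
#   images = predict("schnell", example, reference, num_inference_steps=4)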
css = r""" | |
#col-left { | |
margin: 0 auto; | |
max-width: 430px; | |
} | |
#col-mid { | |
margin: 0 auto; | |
max-width: 430px; | |
} | |
#col-right { | |
margin: 0 auto; | |
max-width: 430px; | |
} | |
#col-showcase { | |
margin: 0 auto; | |
max-width: 1100px; | |
} | |
""" | |


with gr.Blocks(css=css) as demo:
    gr.HTML("""
    <div style="display: flex; justify-content: center; text-align:center; flex-direction: column;">
        <h1 style="color: #333;">🪑 Furniture Blending Demo</h1>
        <div style="max-width: 800px; margin: 0 auto;">
            <p style="font-size: 16px;">Upload an image, draw a mask over the areas you want to remove, and upload a furniture reference image.</p>
            <p style="font-size: 16px;">
                For the best results, draw square masks.
                Flux dev gives better results than schnell but is slower.
                The furniture reference should be a single object on a white background.
            </p>
            <p style="font-size: 16px;">
                You can edit the object with the prompt.
                For example, add "red couch" to the prompt to make the couch red.
            </p>
            <br>
            <p style="font-size: 16px;">⚠️ Note that images are compressed to reduce the demo's workload.</p>
        </div>
    </div>
    """)
    with gr.Row() as content:
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                r"""
                <div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px">
                    <div>
                        🪟 Room image with inpainting mask ⬇️
                    </div>
                </div>
                """,
                max_height=50,
            )
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                height="full",
                width="full",
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
            image_and_mask_examples = gr.Examples(
                examples=[
                    make_example(path, None)
                    for path in Path("./examples/scenes").glob("*.png")
                ],
                label="Room examples",
                examples_per_page=6,
                inputs=[image_and_mask],
            )
        with gr.Column(elem_id="col-mid"):
            gr.HTML(
                r"""
                <div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px">
                    <div>
                        🪑 Furniture reference image ⬇️
                    </div>
                </div>
                """,
                max_height=50,
            )
            condition_image = gr.Image(
                label="Furniture Reference",
                type="pil",
                sources=["upload"],
                image_mode="RGB",
            )
            furniture_examples = gr.Examples(
                examples=list(Path("./examples/objects").glob("*.png")),
                label="Furniture examples",
                examples_per_page=6,
                inputs=[condition_image],
            )
        with gr.Column(elem_id="col-right"):
            gr.HTML(
                r"""
                <div style="display: flex; justify-content: start; align-items: center; text-align: center; font-size: 20px">
                    <div>
                        🔥 Press Run ⬇️
                    </div>
                </div>
                """,
                max_height=50,
            )
            results = gr.Gallery(
                label="Result",
                format="png",
                file_types="image",
                show_label=False,
                columns=2,
                allow_preview=True,
                preview=True,
            )
            model_type = gr.Radio(
                choices=["schnell", "dev", "pixart"],
                value="schnell",
                label="Model Type",
            )
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt",
                    value="",
                )
                subfolder = gr.Textbox(
                    label="Subfolder",
                    value="",
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=np.iinfo(np.int32).max,
                    step=1,
                    value=0,
                )
                num_images_per_prompt = gr.Slider(
                    label="Number of images per prompt",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=4,
                )
                crop = gr.Checkbox(
                    label="Crop",
                    value=False,
                )
                margin = gr.Slider(
                    label="Margin",
                    minimum=0,
                    maximum=256,
                    step=16,
                    value=128,
                )
                with gr.Column():
                    max_dimension = gr.Slider(
                        label="Max Dimension",
                        minimum=256,
                        maximum=1024,
                        step=128,
                        value=512,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=4,
                        maximum=30,
                        step=2,
                        value=4,
                    )
    # Change the number of inference steps based on the model type
    model_type.change(
        fn=lambda x: gr.update(value=4 if x == "schnell" else 28),
        inputs=model_type,
        outputs=num_inference_steps,
    )
    run_button.click(
        fn=predict,
        inputs=[
            model_type,
            image_and_mask,
            condition_image,
            prompt,
            subfolder,
            seed,
            num_inference_steps,
            max_dimension,
            margin,
            crop,
            num_images_per_prompt,
        ],
        outputs=[results],
    )

demo.launch()
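
# To run this demo locally (assuming the script is saved as app.py — the
# filename is hypothetical), set the required environment variables first:
#
#   PASSWORD=<secret> ENDPOINT=<inference-url> python app.py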