#!/usr/bin/python3
# -*- coding: utf-8 -*-
import gradio as gr
import os, sys
import json
import copy
import spaces
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.utils import save_image
from transformers import AutoImageProcessor, Dinov2Model
from segment_anything import SamPredictor, sam_model_registry
from diffusers import (
StableDiffusionBlobNetPipeline,
BlobNetModel,
UNet2DConditionModel,
UniPCMultistepScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
)
from huggingface_hub import snapshot_download
sys.path.append(os.getcwd() + '/examples/blobctrl')
from utils.utils import splat_features, viz_score_fn, BLOB_VIS_COLORS, vis_gt_ellipse_from_ellipse
weight_dtype = torch.float16
device = "cuda"
# download blobctrl models
BlobCtrl_path = "examples/blobctrl/models"
if not (os.path.exists(f"{BlobCtrl_path}/blobnet") and os.path.exists(f"{BlobCtrl_path}/unet_lora")):
BlobCtrl_path = snapshot_download(
repo_id="Yw22/BlobCtrl",
local_dir=BlobCtrl_path,
token=os.getenv("HF_TOKEN"),
)
print(f"BlobCtrl checkpoints downloaded to {BlobCtrl_path}")
# download stable-diffusion-v1-5
StableDiffusion_path = "examples/blobctrl/models/stable-diffusion-v1-5"
if not os.path.exists(StableDiffusion_path):
StableDiffusion_path = snapshot_download(
repo_id="sd-legacy/stable-diffusion-v1-5",
local_dir=StableDiffusion_path,
token=os.getenv("HF_TOKEN"),
)
print(f"StableDiffusion checkpoints downloaded to {StableDiffusion_path}")
# download dinov2-large
Dino_path = "examples/blobctrl/models/dinov2-large"
if not os.path.exists(Dino_path):
Dino_path = snapshot_download(
repo_id="facebook/dinov2-large",
local_dir=Dino_path,
token=os.getenv("HF_TOKEN"),
)
print(f"Dino checkpoints downloaded to {Dino_path}")
# download SAM model
SAM_path = "examples/blobctrl/models/sam/sam_vit_h_4b8939.pth"
if not os.path.exists(SAM_path):
os.makedirs(os.path.dirname(SAM_path), exist_ok=True)
import urllib.request
print(f"Downloading SAM model...")
urllib.request.urlretrieve(
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
SAM_path
)
print(f"SAM model downloaded to {SAM_path}")
## load models and pipeline
blobnet_path = "./examples/blobctrl/models/blobnet"
unet_lora_path = "./examples/blobctrl/models/unet_lora"
stable_diffusion_model_path = "./examples/blobctrl/models/stable-diffusion-v1-5"
dinov2_path = "./examples/blobctrl/models/dinov2-large"
sam_path = "./examples/blobctrl/models/sam/sam_vit_h_4b8939.pth"
## unet
print(f"Loading UNet...")
unet = UNet2DConditionModel.from_pretrained(
    stable_diffusion_model_path,
subfolder="unet",
)
with torch.no_grad():
initial_input_channels = unet.config.in_channels
new_conv_in = torch.nn.Conv2d(
initial_input_channels + 1,
unet.conv_in.out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=unet.conv_in.bias is not None,
dtype=unet.dtype,
device=unet.device,
)
new_conv_in.weight.zero_()
new_conv_in.weight[:, :initial_input_channels].copy_(unet.conv_in.weight)
if unet.conv_in.bias is not None:
new_conv_in.bias.copy_(unet.conv_in.bias)
unet.conv_in = new_conv_in
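# The no_grad block above widens conv_in by one input channel while preserving
# the pretrained weights: the added weights start at zero, so the extended UNet
# initially behaves exactly like the original, and the extra input channel
# (presumably a per-pixel condition such as a mask, concatenated to the
# latents) only takes effect through fine-tuning. Sanity check (illustrative):
#
#   assert unet.conv_in.in_channels == unet.config.in_channels + 1  # e.g. 4 -> 5 for SD v1.5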
## blobnet
print(f"Loading BlobNet...")
blobnet = BlobNetModel.from_pretrained(blobnet_path, ignore_mismatched_sizes=True)
## sam
print(f"Loading SAM...")
mobile_sam = sam_model_registry['vit_h'](checkpoint=sam_path).to(device)
mobile_sam.eval()
mobile_predictor = SamPredictor(mobile_sam)
colors = [(255, 0, 0), (0, 255, 0)]
markers = [1, 5]
rgba_colors = [(255, 0, 255, 255), (0, 255, 0, 255), (0, 0, 255, 255)]
## dinov2
print(f"Loading Dinov2...")
dinov2_processor = AutoImageProcessor.from_pretrained(dinov2_path)
dinov2 = Dinov2Model.from_pretrained(dinov2_path).to(device)
## stable diffusion with blobnet pipeline
print(f"Loading StableDiffusionBlobNetPipeline...")
pipeline = StableDiffusionBlobNetPipeline.from_pretrained(
    stable_diffusion_model_path,
unet=unet,
blobnet=blobnet,
torch_dtype=weight_dtype,
dinov2_processor=dinov2_processor,
dinov2=dinov2,
)
print(f"Loading UNetLora...")
pipeline.load_lora_weights(
unet_lora_path,
adapter_name="default",
)
pipeline.set_adapters(["default"])
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
# pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.to(device)
pipeline.set_progress_bar_config(leave=False)
## meta info
logo = r"""
"""
head= r"""
BlobCtrl: A Unified and Flexible Framework for Element-level Image Generation and Editing
"""
descriptions = r"""
Official Gradio Demo for BlobCtrl: A Unified and Flexible Framework for Element-level Image Generation and Editing
🦉 BlobCtrl enables precise, user-friendly element-level visual manipulation.
Main Features: Element-level Add/Remove/Move/Replace/Enlarge/Shrink.
"""
citation = r"""
If BlobCtrl is helpful, please help to ⭐ the GitHub repo. Thanks!
[GitHub Repo](https://github.com/TencentARC/BlobCtrl)
---
📝 **Citation**
If our work is useful for your research, please consider citing:
```bibtex
@misc{li2025blobctrl,
title={BlobCtrl: A Unified and Flexible Framework for Element-level Image Generation and Editing},
author={Yaowei Li and Lingen Li and Zhaoyang Zhang and Xiaoyu Li and Guangzhi Wang and Hongxiang Li and Xiaodong Cun and Ying Shan and Yuexian Zou},
year={2025},
eprint={2503.13434},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
📧 **Contact**
If you have any questions, please feel free to reach out to me at liyaowei@gmail.com.
"""
# - - - - - examples - - - - - #
EXAMPLES= [
[
"examples/blobctrl/assets/results/demo/move_hat/input_image/input_image.png",
"A frog sits on a rock in a pond, with a top hat beside it, surrounded by butterflies and vibrant flowers.",
1.0,
0.0,
0.9,
1248464818,
0,
],
[
"examples/blobctrl/assets/results/demo/move_cup/input_image/input_image.png",
"a rustic wooden table.",
1.0,
0.0,
1.0,
1248464818,
1,
],
[
"examples/blobctrl/assets/results/demo/enlarge_deer/input_image/input_image.png",
"A cute, young deer with large ears standing in a grassy field at sunrise, surrounded by trees.",
1.6,
0.0,
1.0,
1288911487,
2,
],
[
"examples/blobctrl/assets/results/demo/shrink_dragon/input_image/input_image.png",
"A detailed, handcrafted cardboard dragon with red wings and expressive eyes.",
1.0,
0.0,
1.0,
1248464818,
3,
],
[
"examples/blobctrl/assets/results/demo/remove_shit/input_image/input_image.png",
"The background consists of a textured, gray concrete surface with a red brick wall behind it. The bricks are arranged in a classic pattern, showcasing various shades of red and some weathering.",
1.0,
0.0,
1.0,
1248464818,
4,
],
[
"examples/blobctrl/assets/results/demo/remove_cow/input_image/input_image.png",
"A majestic mountain range with rugged peaks under a cloudy sky, and a grassy field in the foreground.",
1.0,
0.0,
1.0,
1248464818,
5,
],
[
"examples/blobctrl/assets/results/demo/compose_rabbit/input_image/input_image.png",
"A cute brown rabbit sitting on a wooden surface with a serene lake and mountains in the background.",
1.0,
0.0,
1.0,
1248464818,
6,
],
[
"examples/blobctrl/assets/results/demo/compose_cake/input_image/input_image.png",
" slice of cake on a light blue background.",
1.2,
0.0,
1.0,
1248464818,
7,
],
[
"examples/blobctrl/assets/results/demo/replace_knife/input_image/input_image.png",
"A slice of cake on a light blue background, with a knife in the center.",
1.2,
0.0,
1.0,
1248464818,
8,
]
]
#
OBJECT_IMAGE_GALLERY = [
["examples/blobctrl/assets/results/demo/move_hat/object_image_gallery/validation_object_region_center.png"],
["examples/blobctrl/assets/results/demo/move_cup/object_image_gallery/validation_object_region_center.png"],
["examples/blobctrl/assets/results/demo/enlarge_deer/object_image_gallery/validation_object_region_center.png"],
["examples/blobctrl/assets/results/demo/shrink_dragon/object_image_gallery/validation_object_region_center.png"],
["examples/blobctrl/assets/results/demo/remove_shit/object_image_gallery/validation_object_region_center.png"],
["examples/blobctrl/assets/results/demo/remove_cow/object_image_gallery/validation_object_region_center.png"],
["examples/blobctrl/assets/results/demo/compose_rabbit/object_image_gallery/validation_object_region_center.png"],
["examples/blobctrl/assets/results/demo/compose_cake/object_image_gallery/validation_object_region_center.png"],
["examples/blobctrl/assets/results/demo/replace_knife/object_image_gallery/validation_object_region_center.png"],
]
ORI_RESULT_GALLERY = [
["examples/blobctrl/assets/results/demo/move_hat/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/move_hat/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/move_hat/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/move_hat/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/move_hat/ori_result_gallery/ori_result_gallery_4.png"],
["examples/blobctrl/assets/results/demo/move_cup/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/move_cup/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/move_cup/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/move_cup/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/move_cup/ori_result_gallery/ori_result_gallery_4.png"],
["examples/blobctrl/assets/results/demo/enlarge_deer/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/enlarge_deer/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/enlarge_deer/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/enlarge_deer/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/enlarge_deer/ori_result_gallery/ori_result_gallery_4.png"],
["examples/blobctrl/assets/results/demo/shrink_dragon/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/shrink_dragon/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/shrink_dragon/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/shrink_dragon/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/shrink_dragon/ori_result_gallery/ori_result_gallery_4.png"],
["examples/blobctrl/assets/results/demo/remove_shit/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_shit/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/remove_shit/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/remove_shit/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/remove_shit/ori_result_gallery/ori_result_gallery_4.png"],
["examples/blobctrl/assets/results/demo/remove_cow/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_cow/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/remove_cow/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/remove_cow/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/remove_cow/ori_result_gallery/ori_result_gallery_4.png"],
["examples/blobctrl/assets/results/demo/compose_rabbit/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_rabbit/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/compose_rabbit/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/compose_rabbit/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/compose_rabbit/ori_result_gallery/ori_result_gallery_4.png"],
["examples/blobctrl/assets/results/demo/compose_cake/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_cake/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/compose_cake/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/compose_cake/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/compose_cake/ori_result_gallery/ori_result_gallery_4.png"],
["examples/blobctrl/assets/results/demo/replace_knife/ori_result_gallery/ori_result_gallery_0.png", "examples/blobctrl/assets/results/demo/replace_knife/ori_result_gallery/ori_result_gallery_1.png", "examples/blobctrl/assets/results/demo/replace_knife/ori_result_gallery/ori_result_gallery_2.png", "examples/blobctrl/assets/results/demo/replace_knife/ori_result_gallery/ori_result_gallery_3.png", "examples/blobctrl/assets/results/demo/replace_knife/ori_result_gallery/ori_result_gallery_4.png"],
]
EDITABLE_BLOB = [
"examples/blobctrl/assets/results/demo/move_hat/editable_blob/editable_blob.png",
"examples/blobctrl/assets/results/demo/move_cup/editable_blob/editable_blob.png",
"examples/blobctrl/assets/results/demo/enlarge_deer/editable_blob/editable_blob.png",
"examples/blobctrl/assets/results/demo/shrink_dragon/editable_blob/editable_blob.png",
"examples/blobctrl/assets/results/demo/remove_shit/editable_blob/editable_blob.png",
"examples/blobctrl/assets/results/demo/remove_cow/editable_blob/editable_blob.png",
"examples/blobctrl/assets/results/demo/compose_rabbit/editable_blob/editable_blob.png",
"examples/blobctrl/assets/results/demo/compose_cake/editable_blob/editable_blob.png",
"examples/blobctrl/assets/results/demo/replace_knife/editable_blob/editable_blob.png",
]
EDITED_RESULT_GALLERY = [
["examples/blobctrl/assets/results/demo/move_hat/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/move_hat/edited_result_gallery/edited_result_gallery_1.png"],
["examples/blobctrl/assets/results/demo/move_cup/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/move_cup/edited_result_gallery/edited_result_gallery_1.png"],
["examples/blobctrl/assets/results/demo/enlarge_deer/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/enlarge_deer/edited_result_gallery/edited_result_gallery_1.png"],
["examples/blobctrl/assets/results/demo/shrink_dragon/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/shrink_dragon/edited_result_gallery/edited_result_gallery_1.png"],
["examples/blobctrl/assets/results/demo/remove_shit/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_shit/edited_result_gallery/edited_result_gallery_1.png"],
["examples/blobctrl/assets/results/demo/remove_cow/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_cow/edited_result_gallery/edited_result_gallery_1.png"],
["examples/blobctrl/assets/results/demo/compose_rabbit/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_rabbit/edited_result_gallery/edited_result_gallery_1.png"],
["examples/blobctrl/assets/results/demo/compose_cake/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_cake/edited_result_gallery/edited_result_gallery_1.png"],
["examples/blobctrl/assets/results/demo/replace_knife/edited_result_gallery/edited_result_gallery_0.png", "examples/blobctrl/assets/results/demo/replace_knife/edited_result_gallery/edited_result_gallery_1.png"],
]
RESULTS_GALLERY = [
["examples/blobctrl/assets/results/demo/move_hat/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/move_hat/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/move_hat/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/move_hat/results_gallery/results_gallery_3.png"],
["examples/blobctrl/assets/results/demo/move_cup/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/move_cup/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/move_cup/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/move_cup/results_gallery/results_gallery_3.png"],
["examples/blobctrl/assets/results/demo/enlarge_deer/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/enlarge_deer/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/enlarge_deer/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/enlarge_deer/results_gallery/results_gallery_3.png"],
["examples/blobctrl/assets/results/demo/shrink_dragon/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/shrink_dragon/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/shrink_dragon/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/shrink_dragon/results_gallery/results_gallery_3.png"],
["examples/blobctrl/assets/results/demo/remove_shit/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_shit/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/remove_shit/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/remove_shit/results_gallery/results_gallery_3.png"],
["examples/blobctrl/assets/results/demo/remove_cow/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/remove_cow/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/remove_cow/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/remove_cow/results_gallery/results_gallery_3.png"],
["examples/blobctrl/assets/results/demo/compose_rabbit/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_rabbit/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/compose_rabbit/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/compose_rabbit/results_gallery/results_gallery_3.png"],
["examples/blobctrl/assets/results/demo/compose_cake/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/compose_cake/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/compose_cake/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/compose_cake/results_gallery/results_gallery_3.png"],
["examples/blobctrl/assets/results/demo/replace_knife/results_gallery/results_gallery_0.png", "examples/blobctrl/assets/results/demo/replace_knife/results_gallery/results_gallery_1.png", "examples/blobctrl/assets/results/demo/replace_knife/results_gallery/results_gallery_2.png", "examples/blobctrl/assets/results/demo/replace_knife/results_gallery/results_gallery_3.png"],
]
ELLIPSE_LISTS = [
[[[[227.10665893554688, 118.85255432128906], [85.48122482299804, 103.65433502197266], 87.37393951416016], [1, 1, 1, 0], 0], [[[361.1066589355469, 367.85255432128906], [85.48122482299804, 103.65433502197266], 87.37393951416016], [1, 1, 1, 0], 1]],
[[[[249.1703643798828, 149.63021850585938], [83.36424179077149, 115.79973449707032], 0.8257154226303101], [1, 1, 1, 0], 0], [[[245.1703643798828, 270.6302185058594], [83.36424179077149, 115.79973449707032], 0.8257154226303101], [1, 1, 1, 0], 1]],
[[[[234.69358825683594, 255.60946655273438], [196.208619140625, 341.067111328125], 15.866915702819824], [1, 1, 1, 0], 0], [[[234.69358825683594, 255.60946655273438], [226.394560546875, 393.538974609375], 15.866915702819824], [1.2, 1, 1, 0], 2], [[[234.69358825683594, 255.60946655273438], [237.71428857421876, 413.21592333984376], 15.866915702819824], [1.05, 1, 1, 0], 2], [[[237.69358825683594, 237.60946655273438], [237.71428857421876, 413.21592333984376], 15.866915702819824], [1.05, 1, 1, 0], 1], [[[237.69358825683594, 233.60946655273438], [237.71428857421876, 413.21592333984376], 15.866915702819824], [1.05, 1, 1, 0], 1]],
[[[[367.17742919921875, 201.1094512939453], [206.3889125696118, 377.8448820272314], 56.17562484741211], [1, 1, 1, 0], 0], [[[367.17742919921875, 201.1094512939453], [147.91468688964844, 297.0842980957031], 56.17562484741211], [0.8, 1, 1, 0], 2], [[[367.17742919921875, 201.1094512939453], [140.518952545166, 282.2300831909179], 56.17562484741211], [0.95, 1, 1, 0], 2], [[[324.17742919921875, 235.1094512939453], [140.518952545166, 282.2300831909179], 56.17562484741211], [0.95, 1, 1, 0], 1], [[[335.17742919921875, 225.1094512939453], [140.518952545166, 282.2300831909179], 56.17562484741211], [0.95, 1, 1, 0], 1]],
[[[[255.23663330078125, 315.4020080566406], [263.64675201416014, 295.38494384765625], 153.8949432373047], [1, 1, 1, 0], 0]],
[[[[335.09979248046875, 236.41409301757812], [168.37833966064454, 345.3470615478516], 0.7639619708061218], [1, 1, 1, 0], 0]],
[[[[256.0, 256.0], [1e-05, 1e-05], 0], [1, 1, 1, 0], 0], [[[271.6672, 275.3536], [136.85061800371966, 303.75044578074284], 177.008], [1, 1, 1, 0], 0], [[[271.6672, 275.3536], [150.53567980409164, 303.75044578074284], 177.008], [1.1, 1, 1.1, 0], 4], [[[271.6672, 275.3536], [158.06246379429624, 318.93796806977997], 177.008], [1.05, 1, 1.1, 0], 2], [[[271.6672, 275.3536], [165.96558698401105, 334.88486647326897], 177.008], [1.05, 1, 1.1, 0], 2], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 177.008], [1.1, 1, 1.1, 0], 4], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 7.00800000000001], [1.1, 1, 1.1, 10], 5], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 3.0080000000000098], [1.1, 1, 1.1, -4], 5], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 177.008], [1.1, 1, 1.1, -6], 5], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 179.008], [1.1, 1, 1.1, 2], 5], [[[271.6672, 275.3536], [182.56214568241217, 334.88486647326897], 178.008], [1.1, 1, 1.1, -1], 5], [[[271.6672, 275.3536], [182.56214568241217, 368.3733531205959], 178.008], [1.1, 1.1, 1.1, -1], 3], [[[271.6672, 275.3536], [182.56214568241217, 349.95468546456607], 178.008], [1.1, 0.95, 1.1, -1], 3], [[[271.6672, 275.3536], [182.56214568241217, 349.95468546456607], 170.008], [1.1, 0.95, 1.1, -8], 5], [[[271.6672, 275.3536], [182.56214568241217, 349.95468546456607], 172.008], [1.1, 0.95, 1.1, 2], 5]],
[[[[256.0, 256.0], [1e-05, 1e-05], 0], [1, 1, 1, 0], 0], [[[256.0, 256.0], [144.81546878700496, 144.81546878700496], 0], [1, 1, 1, 0], 0], [[[256.0, 256.0], [123.09314846895421, 123.09314846895421], 0], [0.85, 1, 1, 0], 2], [[[256.0, 256.0], [110.7838336220588, 110.7838336220588], 0], [0.9, 1, 1, 0], 2], [[[88.0, 418.0], [110.7838336220588, 110.7838336220588], 0], [0.9, 1, 1, 0], 1]],
[[[[164.6718292236328, 385.8408508300781], [41.45796089172364, 319.87034912109374], 142.05267333984375], [1, 1, 1, 0], 0]],
]
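# Each ELLIPSE_LISTS entry is the edit history of one demo example. Every step
# appears to follow the (ellipse, transform_param, blob_edited_type) layout used
# throughout this file: ellipse is the cv2-style ((xc, yc), (d1, d2), angle)
# triple, transform_param is (resize_keep_aspect, resize_long_axis,
# resize_short_axis, rotation_degrees), and blob_edited_type uses the encoding
# documented in generate_blob (0: init, 1: move, 2: resize keeping aspect ratio,
# 3: resize along long axis, 4: resize along short axis, 5: rotation).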
TRACKING_POINTS = [
[[227, 118], [361, 367]],
[[249, 150], [248, 269]],
[[234, 255], [234, 255], [234, 255], [237, 237], [237, 233]],
[[367, 201], [367, 201], [367, 201], [324, 235], [335, 225]],
[[255, 315]],
[[335, 236]],
[[256, 256], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271], [275, 271]],
[[256, 256], [256, 256], [256, 256], [256, 256], [88, 418]],
[[164, 385]],
]
REMOVE_STATE=[
False,
False,
False,
False,
True,
True,
False,
False,
False,
]
INPUT_IMAGE=[
"examples/blobctrl/assets/results/demo/move_hat/input_image/input_image.png",
"examples/blobctrl/assets/results/demo/move_cup/input_image/input_image.png",
"examples/blobctrl/assets/results/demo/enlarge_deer/input_image/input_image.png",
"examples/blobctrl/assets/results/demo/shrink_dragon/input_image/input_image.png",
"examples/blobctrl/assets/results/demo/remove_shit/input_image/input_image.png",
"examples/blobctrl/assets/results/demo/remove_cow/input_image/input_image.png",
"examples/blobctrl/assets/results/demo/compose_rabbit/input_image/input_image.png",
"examples/blobctrl/assets/results/demo/compose_cake/input_image/input_image.png",
"examples/blobctrl/assets/results/demo/replace_knife/input_image/input_image.png",
]
## helper functions
def _get_ellipse(mask):
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
contours_copy = [contour_copy.tolist() for contour_copy in contours]
concat_contours = np.concatenate(contours, axis=0)
hull = cv2.convexHull(concat_contours)
ellipse = cv2.fitEllipse(hull)
return ellipse, contours_copy
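# Minimal usage sketch for _get_ellipse on a synthetic binary mask (illustrative
# only, with hypothetical values):
#
#   mask = np.zeros((512, 512), dtype=np.uint8)
#   cv2.circle(mask, (256, 256), 80, 255, -1)
#   ellipse, contours = _get_ellipse(mask)
#   # ellipse is cv2's ((xc, yc), (d1, d2), angle) with full axis lengths in pixels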
def ellipse_to_gaussian(x, y, a, b, theta):
"""
将椭圆参数转换为高斯分布的均值和协方差矩阵。
参数:
x (float): 椭圆中心的 x 坐标。
y (float): 椭圆中心的 y 坐标。
a (float): 椭圆的短半轴长度。
b (float): 椭圆的长半轴长度。
theta (float): 椭圆的旋转角度(以弧度为单位), 长半轴逆时针角度。
返回:
mean (numpy.ndarray): 高斯分布的均值,形状为 (2,) 的数组,表示 (x, y) 坐标。
cov_matrix (numpy.ndarray): 高斯分布的协方差矩阵,形状为 (2, 2) 的数组。
"""
    # Mean
    mean = np.array([x, y])
    # Diagonal entries of the covariance:
    # sigma_x = b / np.sqrt(2)
    # sigma_y = a / np.sqrt(2)
    # Not dividing by sqrt(2) also works; the division mainly matters in a
    # statistical context where each semi-axis should correspond to one
    # standard deviation of the Gaussian, so that the ellipse contains about
    # 68% of the probability mass (one standard deviation of a 1-D Gaussian
    # covers roughly 68% of its mass).
    sigma_x = b
    sigma_y = a
    # Covariance matrix (unrotated)
    cov_matrix = np.array([[sigma_x**2, 0],
                           [0, sigma_y**2]])
    # Rotation matrix
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    # Rotate the covariance matrix
    cov_matrix_rotated = R @ cov_matrix @ R.T
    cov_matrix_rotated[0, 1] *= -1  # flip the off-diagonal elements of the covariance matrix
    cov_matrix_rotated[1, 0] *= -1  # (likely to match the image coordinate convention, y pointing down)
# eigenvalues, eigenvectors = np.linalg.eig(cov_matrix_rotated)
return mean, cov_matrix_rotated
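# Worked example (axis-aligned, theta=0): an ellipse centered at (100, 50) with
# semi-axes a=10 and b=20 maps to mean = [100, 50] and
# cov = [[b**2, 0], [0, a**2]] = [[400, 0], [0, 100]], i.e. one semi-axis
# length corresponds to one standard deviation of the Gaussian.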
def normalize_gs(mean, cov_matrix_rotated, width, height):
    # Normalize the mean
    normalized_mean = mean / np.array([width, height])
    # Diagonal length of the image, used to normalize the covariance matrix
    max_length = np.sqrt(width**2 + height**2)
    # Normalize the covariance matrix
    normalized_cov_matrix = cov_matrix_rotated / (max_length ** 2)
return normalized_mean, normalized_cov_matrix
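# E.g. for a 512x512 image the normalized means land in [0, 1] per axis and the
# covariance is divided by the squared diagonal length,
# max_length**2 = 512**2 + 512**2 = 524288.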
def normalize_ellipse(ellipse, width, height):
(xc,yc), (d1,d2), angle_clockwise_short_axis = ellipse
max_length = np.sqrt(width**2 + height**2)
normalized_xc, normalized_yc = xc/width, yc/height
normalized_d1, normalized_d2 = d1/max_length, d2/max_length
return normalized_xc, normalized_yc, normalized_d1, normalized_d2, angle_clockwise_short_axis
def composite_mask_and_image(mask, image, masked_color=[0,0,0]):
if isinstance(mask, Image.Image):
mask_np = np.array(mask)
else:
mask_np = mask
if isinstance(image, Image.Image):
image_np = np.array(image)
else:
image_np = image
if mask_np.ndim == 2:
mask_indicator = (mask_np>0).astype(np.uint8)
else:
mask_indicator = (mask_np.sum(-1)>255).astype(np.uint8)
masked_image = image_np * (1-mask_indicator[:,:,np.newaxis]) + masked_color * mask_indicator[:,:,np.newaxis]
return Image.fromarray(masked_image.astype(np.uint8)).convert("RGB")
def is_point_in_ellipse(point, ellipse):
    # Unpack the ellipse parameters
    (xc, yc), (d1, d2), angle = ellipse
    # Convert the angle to radians
    theta = np.radians(angle)
    # Coordinates relative to the ellipse center
    x, y = point
    x_prime = x - xc
    y_prime = y - yc
    # Rotate into the ellipse's frame
    x_rotated = x_prime * np.cos(theta) - y_prime * np.sin(theta)
    y_rotated = x_prime * np.sin(theta) + y_prime * np.cos(theta)
    # Evaluate the ellipse equation; d1 and d2 are full axis lengths, so halve them
    ellipse_equation = (x_rotated**2) / ((d1 / 2)**2) + (y_rotated**2) / ((d2 / 2)**2)
    # The point lies inside the ellipse iff the value is <= 1
    return ellipse_equation <= 1
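# Quick check with hypothetical values: for ellipse ((100, 100), (40, 20), 0),
# the point (110, 100) gives 10**2 / (40/2)**2 = 0.25 <= 1 (inside), while
# (130, 100) gives 30**2 / (40/2)**2 = 2.25 > 1 (outside).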
def calculate_ellipse_vertices(ellipse):
(xc, yc), (d1, d2), angle_clockwise_short_axis = ellipse
# Convert angle from degrees to radians
angle_rad = np.deg2rad(angle_clockwise_short_axis)
# Calculate the rotation matrix
rotation_matrix = np.array([
[np.cos(angle_rad), -np.sin(angle_rad)],
[np.sin(angle_rad), np.cos(angle_rad)]
])
# Calculate the unrotated vertices
half_d1 = d1 / 2
half_d2 = d2 / 2
vertices = np.array([
[half_d1, 0], # Rightmost point on the long axis
[-half_d1, 0], # Leftmost point on the long axis
[0, half_d2], # Topmost point on the short axis
[0, -half_d2] # Bottommost point on the short axis
])
# Rotate the vertices
rotated_vertices = np.dot(vertices, rotation_matrix.T)
# Translate vertices to the original center
final_vertices = rotated_vertices + np.array([xc, yc])
return final_vertices
def move_ellipse(ellipse, tracking_points):
(xc,yc), (d1,d2), angle_clockwise_short_axis = ellipse
last_xc, last_yc = tracking_points[-1]
second_last_xc, second_last_yc = tracking_points[-2]
vx = last_xc - second_last_xc
vy = last_yc - second_last_yc
xc += vx
yc += vy
return (xc,yc), (d1,d2), angle_clockwise_short_axis
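# move_ellipse applies only the delta between the last two tracking points,
# e.g. tracking_points = [[10, 10], [25, 30]] shifts (xc, yc) by (+15, +20);
# earlier segments are already baked into previous ellipse_lists entries.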
def resize_blob_func(ellipse, resizing_factor, height, width, resize_type):
(xc,yc), (d1,d2), angle_clockwise_short_axis = ellipse
too_big = False
too_small = False
min_blob_area = 1600
exceed_threshold = 0.4
while True:
if resize_type == 0:
resized_d1 = d1 * resizing_factor
resized_d2 = d2 * resizing_factor
elif resize_type == 1:
resized_d1 = d1
resized_d2 = d2 * resizing_factor
elif resize_type == 2:
resized_d1 = d1 * resizing_factor
resized_d2 = d2
resized_ellipse = (xc,yc), (resized_d1, resized_d2), angle_clockwise_short_axis
resized_ellipse_vertices = calculate_ellipse_vertices(resized_ellipse)
resized_ellipse_vertices = resized_ellipse_vertices / np.array([width, height])
if resizing_factor != 1:
            # soft bound: allow the resized ellipse to exceed the image range by up to exceed_threshold
if np.all(resized_ellipse_vertices >= -exceed_threshold) and np.all(resized_ellipse_vertices <= 1+exceed_threshold):
# calculate the blob area
blob_area = np.pi * (resized_d1 / 2) * (resized_d2 / 2)
if blob_area >= min_blob_area:
break
else:
too_small = True
resizing_factor += 0.1
if blob_area < 1e-6:
                        ## bail out if the blob area becomes vanishingly small
break
else:
too_big = True
resizing_factor -= 0.1
else:
break
if too_big:
gr.Warning(f"The blob is too big, adaptive reduction of magnification to fit the image, The threshold allowed to exceed the image range is {exceed_threshold}")
if too_small:
gr.Warning(f"The blob is too small, adaptive enlargement of magnification to fit the image, The minimum blob area is {min_blob_area} px")
return resized_ellipse, resizing_factor
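# resize_blob_func adapts the factor in 0.1 steps (when the requested factor is
# not 1): it shrinks the factor while any resized ellipse vertex falls outside
# the soft image bounds (up to exceed_threshold beyond the normalized [0, 1]
# range is tolerated), and grows it while the ellipse area
# pi * (d1/2) * (d2/2) stays below min_blob_area.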
def rotate_blob_func(ellipse, rotation_degree):
(xc,yc), (d1,d2), angle_clockwise_short_axis = ellipse
rotated_angle_clockwise_short_axis = (angle_clockwise_short_axis + rotation_degree) % 180
rotated_ellipse = (xc,yc), (d1,d2), rotated_angle_clockwise_short_axis
return rotated_ellipse, rotation_degree
def get_theta_anti_clockwise_long_axis(angle_clockwise_short_axis):
angle_anti_clockwise_short_axis = (180 - angle_clockwise_short_axis) % 180
angle_anti_clockwise_long_axis = (angle_anti_clockwise_short_axis + 90) % 180
theta_anti_clockwise_long_axis = np.radians(angle_anti_clockwise_long_axis)
return theta_anti_clockwise_long_axis
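# Worked example: the code stores angles clockwise relative to the short axis
# (see angle_clockwise_short_axis), so an input of 30 degrees becomes
# (180 - 30) % 180 = 150 anti-clockwise for the short axis,
# (150 + 90) % 180 = 60 for the long axis, returned in radians (~1.047).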
def get_gs_from_ellipse(ellipse):
(xc,yc), (d1,d2), angle_clockwise_short_axis = ellipse
theta_anti_clockwise_long_axis = get_theta_anti_clockwise_long_axis(angle_clockwise_short_axis)
a = d1 / 2
b = d2 / 2
mean, cov_matrix = ellipse_to_gaussian(xc, yc, a, b, theta_anti_clockwise_long_axis)
return mean, cov_matrix
def get_blob_dict_from_norm_gs(normalized_mean, normalized_cov_matrix):
xs, ys = normalized_mean
blob = {
"xs": torch.tensor(xs).unsqueeze(0),
"ys": torch.tensor(ys).unsqueeze(0),
"covs": torch.tensor(normalized_cov_matrix).unsqueeze(0).unsqueeze(0),
"sizes": torch.tensor([1.0]).unsqueeze(0),
}
return blob
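# The returned dict is the single-blob layout consumed by splat_features:
# xs and ys have shape (1,), covs has shape (1, 1, 2, 2), sizes has shape (1, 1).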
def clear_ellipse_lists(ellipse_lists):
ellipse_lists = []
return ellipse_lists
def get_blob_vis_img_from_blob_dict(blob, viz_size=64, score_size=64):
blob_vis = splat_features(**blob,
interp_size=64,
viz_size=viz_size,
is_viz=True,
ret_layout=True,
score_size=score_size,
viz_score_fn=viz_score_fn,
viz_colors=BLOB_VIS_COLORS,
only_vis=True)["feature_img"]
blob_vis_img = blob_vis[0].permute(1,2,0).contiguous().cpu().numpy()
blob_vis_img = (blob_vis_img*255).astype(np.uint8)
blob_vis_img = Image.fromarray(blob_vis_img)
return blob_vis_img
def get_blob_score_from_blob_dict(blob, score_size=64):
blob_score = splat_features(**blob,
score_size=score_size,
return_d_score=True,
)[0]
return blob_score
def get_object_region_from_mask(mask, original_image):
if isinstance(mask, Image.Image):
mask_np = np.array(mask)
else:
mask_np = mask
if mask_np.ndim == 2:
mask_indicator = (mask_np>0).astype(np.uint8)
else:
mask_indicator = (mask_np.sum(-1)>255).astype(np.uint8)
x, y, w, h = cv2.boundingRect(mask_indicator)
rect_mask = mask_indicator[y:y+h, x:x+w]
tmp = original_image.copy()
rect_region = tmp[y:y+h, x:x+w]
rect_region_object_white_background = np.where(rect_mask[:, :, None] > 0, rect_region, 255)
target_height, target_width = tmp.shape[:2]
start_y = (target_height - h) // 2
start_x = (target_width - w) // 2
rect_region_object_white_background_center = np.ones((target_height, target_width, 3), dtype=np.uint8) * 255
rect_region_object_white_background_center[start_y:start_y+h, start_x:start_x+w] = rect_region_object_white_background
rect_region_object_white_background_center = Image.fromarray(rect_region_object_white_background_center).convert("RGB")
return rect_region_object_white_background_center
def extract_contours(object_image):
"""
从物体图像中提取轮廓
:param object_image: 输入的物体图像,形状为 (h, w, 3),值范围为 [0, 255]
:return: 轮廓图像,
"""
# 将图像转换为灰度图
gray_image = cv2.cvtColor(object_image, cv2.COLOR_BGR2GRAY)
# 将图像二值化,假设物体不是白色 [255, 255, 255]
_, binary_image = cv2.threshold(gray_image, 240, 255, cv2.THRESH_BINARY_INV)
# 提取轮廓
contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# 创建一个空白图像用于绘制轮廓
contour_image = np.zeros_like(gray_image)
# 在空白图像上绘制轮廓
cv2.drawContours(contour_image, contours, -1, (255), thickness=cv2.FILLED)
return contour_image
def get_mask_from_ellipse(ellipse, height, width):
ellipse_mask_np = np.zeros((height, width))
ellipse_mask_np = cv2.ellipse(ellipse_mask_np, ellipse, 255, -1)
ellipse_mask = Image.fromarray(ellipse_mask_np).convert("L")
return ellipse_mask
# gradio functions
@spaces.GPU(duration=100)
def run_function(
original_image,
scene_prompt,
ori_result_gallery,
object_image_gallery,
edited_result_gallery,
ellipse_lists,
blobnet_control_strength,
blobnet_control_guidance_start,
blobnet_control_guidance_end,
remove_blob_box,
num_samples,
seed,
guidance_scale,
num_inference_steps,
## for save
editable_blob,
resize_blob_slider_maintain_aspect_ratio,
resize_blob_slider_along_long_axis,
resize_blob_slider_along_short_axis,
rotation_blob_slider,
resize_init_blob_slider,
resize_init_blob_slider_long_axis,
resize_init_blob_slider_short_axis,
tracking_points,
):
    if not object_image_gallery or not ori_result_gallery:
        gr.Warning("Please generate the blob first")
        return None
    if not edited_result_gallery:
        gr.Warning("Please click a region inside the blob first.")
        return None
generator = torch.Generator(device=device).manual_seed(seed)
## prepare img: object_region_center, edited_background_region
gt_i_ellipse_img_path, masked_image_path, mask_image_path, ellipse_mask_path, ellipse_masked_image_path = ori_result_gallery
object_white_background_center_path = object_image_gallery[0]
validation_object_region_center = Image.open(object_white_background_center_path[0])
ori_ellipse_mask = Image.open(ellipse_mask_path[0])
width, height = validation_object_region_center.size
latent_height, latent_width = height // 8, width // 8
if not remove_blob_box:
edited_ellipse_masked_image_path, edited_ellipse_mask_path = edited_result_gallery
validation_edited_background_region = Image.open(edited_ellipse_masked_image_path[0])
## prepare gs_score
final_ellipse, final_transform_param, final_blob_edited_type = ellipse_lists[-1]
mean, cov_matrix = get_gs_from_ellipse(final_ellipse)
normalized_mean, normalized_cov_matrix = normalize_gs(mean, cov_matrix, width, height)
blob_dict = get_blob_dict_from_norm_gs(normalized_mean, normalized_cov_matrix)
validation_gs_score = get_blob_score_from_blob_dict(blob_dict, score_size=(latent_height, latent_width)).unsqueeze(0).to(device) # bnhw
else:
img_tmp = original_image.copy()
validation_edited_background_region = composite_mask_and_image(ori_ellipse_mask, img_tmp, masked_color=[255,255,255])
## prepare gs_score
start_ellipse, start_transform_param, start_blob_edited_type = ellipse_lists[0]
mean, cov_matrix = get_gs_from_ellipse(start_ellipse)
normalized_mean, normalized_cov_matrix = normalize_gs(mean, cov_matrix, width, height)
blob_dict = get_blob_dict_from_norm_gs(normalized_mean, normalized_cov_matrix)
validation_gs_score = get_blob_score_from_blob_dict(blob_dict, score_size=(latent_height, latent_width)).unsqueeze(0).to(device) # bnhw
validation_gs_score[:,0] = 1.0
validation_gs_score[:,1] = 0.0
final_ellipse = start_ellipse
## set blobnet control strength to 0.0
blobnet_control_strength = 0.0
with torch.autocast("cuda"):
output = pipeline(
fg_image=validation_object_region_center,
bg_image=validation_edited_background_region,
gs_score=validation_gs_score,
generator=generator,
prompt=[scene_prompt]*num_samples,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
blobnet_control_guidance_start=float(blobnet_control_guidance_start),
blobnet_control_guidance_end=float(blobnet_control_guidance_end),
blobnet_conditioning_scale=float(blobnet_control_strength),
width=width,
height=height,
return_sample=False,
)
edited_images = output.images
edit_image_plots = []
for i in range(num_samples):
edit_image_np = np.array(edited_images[i])
edit_image_np_plot = cv2.ellipse(edit_image_np, final_ellipse, [0,255,0], 3)
edit_image_plot = Image.fromarray(edit_image_np_plot).convert("RGB")
edit_image_plots.append(edit_image_plot)
results_gallery = [*edited_images, *edit_image_plots]
## save results
# ori_save_path = "examples/blobctrl/assets/results/tmp/ori_result_gallery"
# os.makedirs(ori_save_path, exist_ok=True)
# # import ipdb; ipdb.set_trace()
# for i in range(len(ori_result_gallery)):
# result = Image.open(ori_result_gallery[i][0])
# result.save(f"{ori_save_path}/ori_result_gallery_{i}.png")
# object_save_path = "examples/blobctrl/assets/results/tmp/object_image_gallery"
# os.makedirs(object_save_path, exist_ok=True)
# validation_object_region_center.save(f"{object_save_path}/validation_object_region_center.png")
# edited_save_path = "examples/blobctrl/assets/results/tmp/edited_result_gallery"
# os.makedirs(edited_save_path, exist_ok=True)
# for i in range(len(edited_result_gallery)):
# result = Image.open(edited_result_gallery[i][0])
# result.save(f"{edited_save_path}/edited_result_gallery_{i}.png")
# results_save_path = "examples/blobctrl/assets/results/tmp/results_gallery"
# os.makedirs(results_save_path, exist_ok=True)
# for i in range(len(results_gallery)):
# results_gallery[i].save(f"{results_save_path}/results_gallery_{i}.png")
# editable_blob_save_path = "examples/blobctrl/assets/results/tmp/editable_blob"
# os.makedirs(editable_blob_save_path, exist_ok=True)
# editable_blob_pil = Image.fromarray(editable_blob)
# editable_blob_pil.save(f"{editable_blob_save_path}/editable_blob.png")
# state_save_path = "examples/blobctrl/assets/results/tmp/state"
# os.makedirs(state_save_path, exist_ok=True)
# with open(f"{state_save_path}/state.json", "w") as f:
# json.dump({
# "blobnet_control_strength": blobnet_control_strength,
# "blobnet_control_guidance_start": blobnet_control_guidance_start,
# "blobnet_control_guidance_end": blobnet_control_guidance_end,
# "remove_blob_box": remove_blob_box,
# "num_samples": num_samples,
# "seed": seed,
# "guidance_scale": guidance_scale,
# "num_inference_steps": num_inference_steps,
# "ellipse_lists": ellipse_lists,
# "scene_prompt": scene_prompt,
# "resize_blob_slider_maintain_aspect_ratio": resize_blob_slider_maintain_aspect_ratio,
# "resize_blob_slider_along_long_axis": resize_blob_slider_along_long_axis,
# "resize_blob_slider_along_short_axis": resize_blob_slider_along_short_axis,
# "rotation_blob_slider": rotation_blob_slider,
# "resize_init_blob_slider": resize_init_blob_slider,
# "resize_init_blob_slider_long_axis": resize_init_blob_slider_long_axis,
# "resize_init_blob_slider_short_axis": resize_init_blob_slider_short_axis,
# "tracking_points": tracking_points,
# }, f)
# input_image_save_path = "examples/blobctrl/assets/results/tmp/input_image"
# os.makedirs(input_image_save_path, exist_ok=True)
# Image.fromarray(original_image).save(f"{input_image_save_path}/input_image.png")
torch.cuda.empty_cache()
return results_gallery
def generate_blob(
original_image,
original_mask,
selected_points,
ellipse_lists,
init_resize_factor=1.05,
):
if original_image is None:
raise gr.Error('Please upload the input image')
if (original_mask is None) or (len(selected_points)==0):
raise gr.Error("Please click the region where you hope unchanged/changed in input image to get segmentation mask")
else:
original_mask = np.clip(255 - original_mask, 0, 255).astype(np.uint8)
## get ellipse parameters from mask
height, width = original_image.shape[:2]
binary_mask = 255*(original_mask.sum(-1)>255).astype(np.uint8)
ellipse, contours = _get_ellipse(binary_mask)
## properly enlarge ellipse to cover the whole blob
ellipse, resizing_factor = resize_blob_func(ellipse, init_resize_factor, height, width, 0)
## get gaussian parameters from ellipse
mean, cov_matrix = get_gs_from_ellipse(ellipse)
normalized_mean, normalized_cov_matrix = normalize_gs(mean, cov_matrix, width, height)
blob_dict = get_blob_dict_from_norm_gs(normalized_mean, normalized_cov_matrix)
blob_vis_img = get_blob_vis_img_from_blob_dict(blob_dict, viz_size=(height, width))
## plot masked image
masked_image = composite_mask_and_image(original_mask, original_image)
mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("L")
## get object region
object_white_background_center = get_object_region_from_mask(original_mask, original_image)
## plot ellipse
gt_i_ellipse = vis_gt_ellipse_from_ellipse(torch.tensor(original_image).round().contiguous().cpu().numpy(),
ellipse,
color=[0,255,0])
gt_i_ellipse_img = Image.fromarray(gt_i_ellipse.astype(np.uint8))
ellipse_mask = get_mask_from_ellipse(ellipse, height, width)
ellipse_masked_image = composite_mask_and_image(ellipse_mask, original_image)
## return images
ori_result_gallery = [gt_i_ellipse_img, masked_image, mask_image, ellipse_mask, ellipse_masked_image]
object_image_gallery = [object_white_background_center]
## init ellipse_lists, 0: init, 1: move , 2: resize remain aspect ratio, 3: resize along long axis, 4: resize along short axis, 5: rotation
## ellipse_int = (ellipse, (resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle), blob_edited_type)
ellipse_init = (ellipse, (1, 1, 1, 0), 0)
if len(ellipse_lists) == 0:
ellipse_lists.append(ellipse_init)
else:
ellipse_lists = clear_ellipse_lists(ellipse_lists)
ellipse_lists.append(ellipse_init)
## init parameters
rotation_blob_slider = 0
resize_blob_slider_maintain_aspect_ratio = 1
resize_blob_slider_along_long_axis = 1
resize_blob_slider_along_short_axis = 1
resize_init_blob_slider = 1
resize_init_blob_slider_long_axis = 1
resize_init_blob_slider_short_axis = 1
init_ellipse_parameter = None
init_object_image = None
tracking_points = []
edited_result_gallery = None
return blob_vis_img, ori_result_gallery, object_image_gallery, ellipse_lists, tracking_points, edited_result_gallery, resize_blob_slider_maintain_aspect_ratio, resize_blob_slider_along_long_axis, resize_blob_slider_along_short_axis, rotation_blob_slider, resize_init_blob_slider, resize_init_blob_slider_long_axis, resize_init_blob_slider_short_axis, init_ellipse_parameter, init_object_image
# undo the selected point
def undo_seg_points(orig_img, sel_pix):
# draw points
output_mask = None
if len(sel_pix) != 0:
temp = orig_img.copy()
sel_pix.pop()
# online show seg mask
if len(sel_pix) !=0:
temp, output_mask = segmentation(temp, sel_pix)
return temp.astype(np.uint8), output_mask
else:
gr.Warning("Nothing to Undo")
# once user upload an image, the original image is stored in `original_image`
def initialize_img(img):
if max(img.shape[0], img.shape[1])*1.0/min(img.shape[0], img.shape[1])>2.0:
raise gr.Error('image aspect ratio cannot be larger than 2.0')
    # Resize and center-crop to 512x512
h, w = img.shape[:2]
# First resize so shortest side is 512
scale = 512 / min(h, w)
new_h = int(h * scale)
new_w = int(w * scale)
img = cv2.resize(img, (new_w, new_h))
# Then crop to 512x512
h, w = img.shape[:2]
start_y = (h - 512) // 2
start_x = (w - 512) // 2
img = img[start_y:start_y+512, start_x:start_x+512]
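    # E.g. a 768x1024 (h x w) input is scaled by 512/768 to 512x682, then
    # center-cropped to 512x512 with start_x = (682 - 512) // 2 = 85.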
original_image = img.copy()
editable_blob = None
selected_points = []
tracking_points = []
ellipse_lists = []
ori_result_gallery = []
object_image_gallery = []
edited_result_gallery = []
results_gallery = []
blobnet_control_strength = 1.2
blobnet_control_guidance_start = 0.0
blobnet_control_guidance_end = 1.0
resize_blob_slider_maintain_aspect_ratio = 1
resize_blob_slider_along_long_axis = 1
resize_blob_slider_along_short_axis = 1
rotation_blob_slider = 0
resize_init_blob_slider = 1
resize_init_blob_slider_long_axis = 1
resize_init_blob_slider_short_axis = 1
init_ellipse_parameter = "[0.5, 0.5, 0.2, 0.2, 180]"
init_object_image = None
remove_blob_box = False
return img, original_image, editable_blob, selected_points, tracking_points, ellipse_lists, ori_result_gallery, object_image_gallery, edited_result_gallery, results_gallery, blobnet_control_strength, blobnet_control_guidance_start, blobnet_control_guidance_end, resize_blob_slider_maintain_aspect_ratio, resize_blob_slider_along_long_axis, resize_blob_slider_along_short_axis, rotation_blob_slider, resize_init_blob_slider, resize_init_blob_slider_long_axis, resize_init_blob_slider_short_axis, init_ellipse_parameter, init_object_image, remove_blob_box
# user click the image to get points, and show the points on the image
def segmentation(img, sel_pix):
# online show seg mask
points = []
labels = []
for p, l in sel_pix:
points.append(p)
labels.append(l)
mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
with torch.no_grad():
masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
output_mask = np.ones((masks.shape[1], masks.shape[2], 3))*255
for i in range(3):
output_mask[masks[0] == True, i] = 0.0
mask_all = np.ones((masks.shape[1], masks.shape[2], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
mask_all[masks[0] == True, i] = color_mask[i]
masked_img = img / 255 * 0.3 + mask_all * 0.7
masked_img = masked_img*255
## draw points
for point, label in sel_pix:
cv2.drawMarker(masked_img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
return masked_img, output_mask
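# segmentation() runs SAM on the clicked points, tints the predicted region
# with a random color at 70% opacity over the image, and returns output_mask
# as a white (255) canvas with the object region set to black (0).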
def get_point(img, sel_pix, evt: gr.SelectData):
sel_pix.append((evt.index, 1)) # default foreground_point
# online show seg mask
masked_img, output_mask = segmentation(img, sel_pix)
return masked_img.astype(np.uint8), output_mask
def tracking_points_for_blob(original_image,
tracking_points,
ellipse_lists,
height,
width,
edit_status=True):
sel_pix_transparent_layer = np.zeros((height, width, 4))
sel_ell_transparent_layer = np.zeros((height, width, 4))
start_ellipse, start_transform_param, start_blob_edited_type = ellipse_lists[0]
current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[-1]
## plot start point
start_point = tracking_points[0]
cv2.drawMarker(sel_pix_transparent_layer, start_point, rgba_colors[-1], markerType=markers[1], markerSize=20, thickness=5)
## plot tracking points
if len(tracking_points) > 1:
tracking_points_real = []
for point in tracking_points:
if not tracking_points_real or point != tracking_points_real[-1]:
tracking_points_real.append(point)
for i in range(len(tracking_points_real)-1):
start_point = tracking_points_real[i]
end_point = tracking_points_real[i+1]
vx = end_point[0] - start_point[0]
vy = end_point[1] - start_point[1]
arrow_length = np.sqrt(vx**2 + vy**2)
## draw arrow
if i == len(tracking_points_real)-2:
cv2.arrowedLine(sel_pix_transparent_layer, tuple(start_point), tuple(end_point), rgba_colors[-1], 2, tipLength=8 / arrow_length)
else:
cv2.line(sel_pix_transparent_layer, tuple(start_point), tuple(end_point), rgba_colors[-1], 2,)
if edit_status:
edited_ellipse = move_ellipse(current_ellipse, tracking_points_real)
transform_param = current_transform_param
ellipse_lists.append((edited_ellipse, transform_param, 1))
    ## draw the ellipse; re-read the current ellipse because ellipse_lists may have changed
current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[-1]
cv2.ellipse(sel_ell_transparent_layer, current_ellipse, rgba_colors[-1], 2, -1)
# get current ellipse
current_mean, current_cov_matrix = get_gs_from_ellipse(current_ellipse)
current_normalized_mean, current_normalized_cov_matrix = normalize_gs(current_mean, current_cov_matrix, width, height)
current_blob_dict = get_blob_dict_from_norm_gs(current_normalized_mean, current_normalized_cov_matrix)
transparent_background = get_blob_vis_img_from_blob_dict(current_blob_dict, viz_size=(height, width)).convert('RGBA')
## composite images
sel_pix_transparent_layer = Image.fromarray(sel_pix_transparent_layer.astype(np.uint8))
sel_ell_transparent_layer = Image.fromarray(sel_ell_transparent_layer.astype(np.uint8))
transform_gs_img = Image.alpha_composite(transparent_background, sel_pix_transparent_layer)
transform_gs_img = Image.alpha_composite(transform_gs_img, sel_ell_transparent_layer)
## get vis edited image and mask
# Use anti-aliasing to get smoother ellipse edges
original_ellipse_mask_np = np.zeros((height, width), dtype=np.float32)
original_ellipse_mask_np = cv2.ellipse(original_ellipse_mask_np, start_ellipse, 1.0, -1, lineType=cv2.LINE_AA)
original_ellipse_mask_np = (original_ellipse_mask_np * 255).astype(np.uint8)
original_ellipse_mask = Image.fromarray(original_ellipse_mask_np).convert("L")
edited_ellipse_mask_np = np.zeros((height, width), dtype=np.float32)
edited_ellipse_mask_np = cv2.ellipse(edited_ellipse_mask_np, current_ellipse, 1.0, -1, lineType=cv2.LINE_AA)
edited_ellipse_mask_np = (edited_ellipse_mask_np * 255).astype(np.uint8)
edited_ellipse_mask = Image.fromarray(edited_ellipse_mask_np).convert("L")
original_ellipse_masked_image = composite_mask_and_image(original_ellipse_mask, original_image, masked_color=[255,255,255])
edited_ellipse_masked_image = composite_mask_and_image(edited_ellipse_mask, original_ellipse_masked_image, masked_color=[0,0,0])
edited_result_gallery = [edited_ellipse_masked_image, edited_ellipse_mask]
return transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery
def add_tracking_points(original_image,
tracking_points,
ellipse_lists,
evt: gr.SelectData): # SelectData is a subclass of EventData
height, width = original_image.shape[:2]
if len(ellipse_lists) == 0:
gr.Warning("Please generate the blob first")
return None, tracking_points, ellipse_lists, None
## get start ellipse
start_ellipse, transform_param, blob_edited_type = ellipse_lists[0]
## check if the point is in the ellipse initially
if not is_point_in_ellipse(evt.index, start_ellipse) and len(tracking_points) == 0:
gr.Warning("Please click the region in the blob in the first time.")
start_mean, start_cov_matrix = get_gs_from_ellipse(start_ellipse)
start_normalized_mean, start_normalized_cov_matrix = normalize_gs(start_mean, start_cov_matrix, width, height)
start_blob_dict = get_blob_dict_from_norm_gs(start_normalized_mean, start_normalized_cov_matrix)
start_transparent_background = get_blob_vis_img_from_blob_dict(start_blob_dict, viz_size=(height, width)).convert('RGBA')
return start_transparent_background, tracking_points, ellipse_lists, None
if len(tracking_points) == 0:
xc, yc = start_ellipse[0]
tracking_points.append([int(xc), int(yc)])
else:
tracking_points.append(evt.index)
tmp_img = original_image.copy()
transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
tracking_points,
ellipse_lists,
height,
width,
edit_status=True)
return transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery
def undo_blob_points(original_image, tracking_points, ellipse_lists):
height, width = original_image.shape[:2]
if len(tracking_points) > 1:
tmp_img = original_image.copy()
tracking_points.pop()
ellipse_lists.pop()
transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
tracking_points,
ellipse_lists,
height,
width,
edit_status=False)
current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[-1]
# resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle = current_transform_param
resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle = 1,1,1,0
return transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery, resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle
else:
if len(tracking_points) == 1:
tracking_points.pop()
else:
gr.Warning("Nothing to Undo")
transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery, resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle = reset_blob_points(original_image, tracking_points, ellipse_lists)
return transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery, resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle
def reset_blob_points(original_image, tracking_points, ellipse_lists):
edited_result_gallery = None
height, width = original_image.shape[:2]
tracking_points = []
start_ellipse, start_transform_param, start_blob_edited_type = ellipse_lists[0]
ellipse_lists = clear_ellipse_lists(ellipse_lists)
ellipse_lists.append((start_ellipse, start_transform_param, start_blob_edited_type))
current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[0]
resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle = current_transform_param
current_mean, current_cov_matrix = get_gs_from_ellipse(current_ellipse)
current_normalized_mean, current_normalized_cov_matrix = normalize_gs(current_mean, current_cov_matrix, width, height)
current_blob_dict = get_blob_dict_from_norm_gs(current_normalized_mean, current_normalized_cov_matrix)
transform_gs_img = get_blob_vis_img_from_blob_dict(current_blob_dict, viz_size=(height, width)).convert('RGBA')
return transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery, resizing_factor_remain_aspect_ratio, resizing_factor_long_axis, resizing_factor_short_axis, anti_clockwise_rotation_angle
def resize_blob(editable_blob,
original_image,
tracking_points,
ellipse_lists,
resizing_factor,
resize_type,
edited_result_gallery,
remove_blob_box):
if remove_blob_box:
gr.Warning("Please use initial blob resize in remove mode to ensure the initial blob surrounds the object")
return editable_blob, ellipse_lists, edited_result_gallery, 1
if len(ellipse_lists) == 0:
gr.Warning("Please generate the blob first")
return None, ellipse_lists, None, 1
if len(tracking_points) == 0:
gr.Warning("Please select the blob first")
return editable_blob, ellipse_lists, None, 1
height, width = original_image.shape[:2]
# resize_type: 0: maintain aspect ratio, 1: along long axis, 2: along short axis
current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[-1]
if resize_type == 0:
edited_ellipse, resizing_factor = resize_blob_func(current_ellipse, resizing_factor, height, width, 0)
transform_param = (resizing_factor, current_transform_param[1], current_transform_param[2], current_transform_param[3])
ellipse_lists.append((edited_ellipse, transform_param, 2))
elif resize_type == 1:
edited_ellipse, resizing_factor = resize_blob_func(current_ellipse, resizing_factor, height, width, 1)
transform_param = (current_transform_param[0], resizing_factor, current_transform_param[2], current_transform_param[3])
ellipse_lists.append((edited_ellipse, transform_param, 3))
elif resize_type == 2:
edited_ellipse, resizing_factor = resize_blob_func(current_ellipse, resizing_factor, height, width, 2)
        transform_param = (current_transform_param[0], current_transform_param[1], resizing_factor, current_transform_param[3])
ellipse_lists.append((edited_ellipse, transform_param, 4))
## reset resizing factor, resize is progressive
resizing_factor = 1
if len(tracking_points) > 0:
tracking_points.append(tracking_points[-1])
else:
xc, yc = edited_ellipse[0]
tracking_points.append([int(xc), int(yc)])
tmp_img = original_image.copy()
transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
tracking_points,
ellipse_lists,
height,
width,
edit_status=False)
return transform_gs_img, ellipse_lists, edited_result_gallery, resizing_factor
def resize_start_blob(editable_blob,
original_image,
tracking_points,
ellipse_lists,
ori_result_gallery,
resizing_factor,
resize_type):
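## Scale the initial blob in place (ellipse_lists[0]), e.g. to make sure it fully
## covers the object before a remove operation, and refresh both previews.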
if len(ellipse_lists) == 0:
gr.Warning("Please generate the blob first")
return None, ellipse_lists, None, None, 1
if len(tracking_points) == 0:
gr.Warning("Please select the blob first")
return editable_blob, ellipse_lists, None, None, 1
height, width = original_image.shape[:2]
## resize start blob for background
current_idx = 0
current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[current_idx]
# resize_type: 0: maintain aspect ratio, 1: along long axis, 2: along short axis
edited_ellipse, resizing_factor = resize_blob_func(current_ellipse, resizing_factor, height, width, resize_type)
transform_param = tuple(current_transform_param)
ellipse_lists[0] = (edited_ellipse, transform_param, 0)
## reset resizing factor, resize is progressive
resizing_factor = 1
tmp_img = original_image.copy()
transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
tracking_points,
ellipse_lists,
height,
width,
edit_status=False)
## ori_result_gallery
gt_i_ellipse_img_path, masked_image_path, mask_image_path, ellipse_mask_path, ellipse_masked_image_path = ori_result_gallery
masked_image = Image.open(masked_image_path[0])
mask_image = Image.open(mask_image_path[0])
## new ellipse mask
current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[current_idx]
new_ellipse_mask_img = get_mask_from_ellipse(current_ellipse, height, width)
new_ellipse_masked_img = composite_mask_and_image(new_ellipse_mask_img, tmp_img)
# hand the helper its own uint8 copy of the frame so the original stays untouched
gt_i_ellipse = vis_gt_ellipse_from_ellipse(tmp_img.copy(), current_ellipse, color=[0, 255, 0])
new_gt_i_ellipse_img = Image.fromarray(gt_i_ellipse.astype(np.uint8))
ori_result_gallery = [new_gt_i_ellipse_img, masked_image, mask_image, new_ellipse_mask_img, new_ellipse_masked_img]
return transform_gs_img, ellipse_lists, edited_result_gallery, ori_result_gallery, resizing_factor
def rotate_blob(editable_blob,
original_image,
tracking_points,
ellipse_lists,
rotation_degree):
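## Rotate the most recent blob by `rotation_degree` degrees, append the result
## to the edit history, and refresh the edited preview.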
if len(ellipse_lists) == 0:
gr.Warning("Please generate the blob first")
return None, ellipse_lists, None, 0
if len(tracking_points) == 0:
gr.Warning("Please select the blob first")
return editable_blob, ellipse_lists, None, 0
height, width = original_image.shape[:2]
current_idx = -1
current_ellipse, current_transform_param, current_blob_edited_type = ellipse_lists[current_idx]
edited_ellipse, rotation_degree = rotate_blob_func(current_ellipse, rotation_degree)
transform_param = (current_transform_param[0], current_transform_param[1], current_transform_param[2], rotation_degree)
ellipse_lists.append((edited_ellipse, transform_param, 5))
## reset rotation degree, rotation is progressive
rotation_degree = 0
if len(tracking_points) > 0:
tracking_points.append(tracking_points[-1])
else:
xc, yc = edited_ellipse[0]
tracking_points.append([int(xc), int(yc)])
tmp_img = original_image.copy()
transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
tracking_points,
ellipse_lists,
height,
width,
edit_status=False)
return transform_gs_img, ellipse_lists, edited_result_gallery, rotation_degree
def remove_blob_box_func(editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, remove_blob_box):
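## Checkbox handler for remove mode: when enabled, grow the initial blob by 20%
## (factor 1.2) so it safely covers the object to erase; when disabled, re-apply
## the identity factor 1.0.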
if remove_blob_box:
return resize_start_blob(editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, 1.2, 0)
else:
return resize_start_blob(editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, 1.0, 0)
def set_init_ellipse(original_image, original_mask, edited_result_gallery, ellipse_lists, tracking_points, editable_blob, ori_result_gallery, init_ellipse_parameter):
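## Build a manual initial blob from a text spec "[xc, yc, d1, d2, angle]": the center is
## normalized to the image width/height, the diameters to the image diagonal, and the
## angle is in degrees.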
## If init_ellipse_parameter is provided, use it as the manual initial ellipse
if init_ellipse_parameter is not None and init_ellipse_parameter != "":
# Parse string input like '[0.5,0.5,0.2,0.2,180]'; json.loads avoids eval() on user input
try:
params = json.loads(init_ellipse_parameter)
normalized_xc, normalized_yc, normalized_d1, normalized_d2, angle_clockwise_short_axis = params
except (ValueError, TypeError):
gr.Warning("Please set a valid initial ellipse first")
return editable_blob, edited_result_gallery, ellipse_lists, tracking_points, ori_result_gallery, "[0.5, 0.5, 0.2, 0.2, 180]"
height, width = original_image.shape[:2]
max_length = np.sqrt(height**2 + width**2)
ellipse_zero = ((width/2, height/2), (1e-5, 1e-5), 0)
ellipse = ((normalized_xc*width, normalized_yc*height), (normalized_d1*max_length, normalized_d2*max_length), angle_clockwise_short_axis)
original_mask = np.array(get_mask_from_ellipse(ellipse, height, width))
original_mask = np.stack([original_mask, original_mask, original_mask], axis=-1)
ellipse_init = (ellipse_zero, (1, 1, 1, 0), 0)
ellipse_next = (ellipse, (1, 1, 1, 0), 0)
if len(ellipse_lists) == 0:
ellipse_lists.append(ellipse_init)
ellipse_lists.append(ellipse_next)
else:
ellipse_lists = clear_ellipse_lists(ellipse_lists)
ellipse_lists.append(ellipse_init)
ellipse_lists.append(ellipse_next)
tmp_img = original_image.copy()
tracking_points = [[int(ellipse_init[0][0][1]), int(ellipse_init[0][0][0])], [int(ellipse_next[0][0][1]), int(ellipse_next[0][0][0])]]
transform_gs_img, tracking_points, ellipse_lists, edited_result_gallery = tracking_points_for_blob(tmp_img,
tracking_points,
ellipse_lists,
height,
width,
edit_status=False)
## plot masked image
masked_image = composite_mask_and_image(original_mask, original_image)
mask_image = Image.fromarray(original_mask.astype(np.uint8)).convert("L")
## plot ellipse
# hand the helper its own uint8 copy of the frame so the original stays untouched
gt_i_ellipse = vis_gt_ellipse_from_ellipse(original_image.copy(), ellipse, color=[0, 255, 0])
gt_i_ellipse_img = Image.fromarray(gt_i_ellipse.astype(np.uint8))
ellipse_mask = get_mask_from_ellipse(ellipse, height, width)
ellipse_masked_image = composite_mask_and_image(ellipse_mask, original_image)
ori_result_gallery = [gt_i_ellipse_img, masked_image, mask_image, ellipse_mask, ellipse_masked_image]
return transform_gs_img, edited_result_gallery, ellipse_lists, tracking_points, ori_result_gallery, None
gr.Warning("Please set the valid initial ellipse first")
return editable_blob, edited_result_gallery, ellipse_lists, tracking_points, ori_result_gallery, "[0.5, 0.5, 0.2, 0.2, 180]"
def upload_object_image(object_image, edited_result_gallery, remove_blob_box):
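## Accept a user-provided object image for compositional generation: center-crop it
## to 512x512 (after scaling the shortest side to 512) and switch remove mode off.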
if not edited_result_gallery:
raise gr.Error("Please generate the blob first")
else:
# Resize so the shortest side is 512, then center-crop to 512x512
h, w = object_image.shape[:2]
scale = 512 / min(h, w)
new_h = int(h * scale)
new_w = int(w * scale)
object_image = cv2.resize(object_image, (new_w, new_h))
# Then crop to 512x512
h, w = object_image.shape[:2]
start_y = (h - 512) // 2
start_x = (w - 512) // 2
object_image = object_image[start_y:start_y+512, start_x:start_x+512]
object_image_gallery = [object_image]
remove_blob_box = False
return object_image_gallery, remove_blob_box
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
gr.HTML(head)
gr.Markdown(descriptions)
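## Session state: original image/mask, resize-mode codes for the sliders,
## SAM click points, blob tracking points, and the ellipse edit history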
original_image = gr.State(value=None)
original_mask = gr.State(value=None)
resize_blob_maintain_aspect_ratio_state = gr.State(value=0)
resize_blob_along_long_axis_state = gr.State(value=1)
resize_blob_along_short_axis_state = gr.State(value=2)
selected_points = gr.State([])
tracking_points = gr.State([])
ellipse_lists = gr.State([])
with gr.Row():
with gr.Column():
with gr.Column(elem_id="Input"):
gr.Markdown("## **Step 1: Upload an image and click to segment the object**", show_label=False)
with gr.Row():
input_image = gr.Image(type="numpy", label="input", scale=2, height=576, interactive=True)
with gr.Row(elem_id="Seg"):
undo_seg_button = gr.Button('🔙 Undo Seg', elem_id="undo_btnSEG", scale=1)
gr.Markdown("## **Step 2: Input the scene prompt and 🎩 generate the blob**", show_label=False)
scene_prompt = gr.Textbox(label="Scene Prompt", value="Fill image using foreground and background.")
generate_blob_button = gr.Button("🎩 Generate Blob", elem_id="btn")
gr.Markdown("### 💡 Hint: Adjust the control strength and control timesteps range to balance appearance and flexibility", show_label=False)
blobnet_control_strength = gr.Slider(label="🎚️ Control Strength:", minimum=0, maximum=2.5, value=1.6, step=0.01)
with gr.Row():
blobnet_control_guidance_start = gr.Slider(label="Blobnet Control Timestep Start", minimum=0, maximum=1, step=0.01, value=0)
blobnet_control_guidance_end = gr.Slider(label="Blobnet Control Timestep End", minimum=0, maximum=1, step=0.01, value=0.9)
gr.Markdown("### Click to adjust the diffusion sampling options 👇", show_label=False)
with gr.Accordion("Diffusion Options", open=False, elem_id="accordion1"):
seed = gr.Slider(
label="Seed: ", minimum=0, maximum=2147483647, step=1, value=1248464818, scale=2
)
num_samples = gr.Slider(
label="Num samples", minimum=1, maximum=4, step=1, value=2
)
with gr.Group():
with gr.Row():
guidance_scale = gr.Slider(label="CFG scale", minimum=1, maximum=12, step=0.1, value=7.5)
num_inference_steps = gr.Slider(label="NFE", minimum=1, maximum=100, step=1, value=50)
with gr.Column():
gr.Markdown("### Click to expand more previews 👇", show_label=False)
with gr.Row():
with gr.Accordion("More Previews", open=False, elem_id="accordion2"):
with gr.Row():
with gr.Column():
with gr.Tab(elem_classes="feedback", label="Object Image"):
object_image_gallery = gr.Gallery(label='Object Image', height=320, elem_id="gallery", show_label=True, interactive=False, preview=True)
with gr.Column():
with gr.Tab(elem_classes="feedback", label="Original Preview"):
ori_result_gallery = gr.Gallery(label='Original Preview', height=320, elem_id="gallery", show_label=True, interactive=False, preview=True)
gr.Markdown("## **Step 3: Edit the blob, such as move/resize/remove the blob**", show_label=False)
with gr.Row():
with gr.Column():
with gr.Tab(elem_classes="feedback", label="Editable Blob"):
editable_blob = gr.Image(label="Editable Blob", height=320, interactive=False, container=True)
with gr.Column():
with gr.Tab(elem_classes="feedback", label="Edited Preview"):
edited_result_gallery = gr.Gallery(label='Edited Preview', height=320, elem_id="gallery", show_label=True, interactive=False, preview=True)
gr.Markdown("### Click to adjust the target blob size 👇", show_label=False)
with gr.Row():
with gr.Group():
resize_blob_slider_maintain_aspect_ratio = gr.Slider(label="Resize Blob (Maintain Aspect Ratio)", minimum=0.1, maximum=2, step=0.05, value=1)
with gr.Row():
undo_blob_button = gr.Button('🔙 Undo Blob', elem_id="undo_btnBlob", scale=1)
reset_blob_button = gr.Button('🔄 Reset Blob', elem_id="reset_btnBlob", scale=1)
gr.Markdown("### Click to adjust the initial blob size to ensure it surrounds the object👇", show_label=False)
with gr.Group():
with gr.Row():
resize_init_blob_slider = gr.Slider(label="Resize Initial Blob (Maintain Aspect Ratio)", minimum=0.1, maximum=2, step=0.05, value=1, scale=4)
with gr.Row():
remove_blob_box = gr.Checkbox(label="Remove Blob", value=False, scale=1)
gr.Markdown("### Click to achieve more edit types, such as single-sided resize, composition, etc. 👇", show_label=False)
with gr.Accordion("More Edit Types", open=False, elem_id="accordion3"):
gr.Markdown("### slide to achieve single-sided resize and rotation", show_label=False)
with gr.Group():
with gr.Row():
resize_blob_slider_along_long_axis = gr.Slider(label="Resize Blob (Along Long Axis)", minimum=0, maximum=2, step=0.05, value=1)
resize_blob_slider_along_short_axis = gr.Slider(label="Resize Blob (Along Short Axis)", minimum=0, maximum=2, step=0.05, value=1)
with gr.Row():
rotation_blob_slider = gr.Slider(label="Rotate Blob (Clockwise)", minimum=-180, maximum=180, step=1, value=0)
gr.Markdown("### slide to adjust the initial blob (single-sided)", show_label=False)
with gr.Group():
with gr.Row():
resize_init_blob_slider_long_axis = gr.Slider(label="Resize Initial Blob (Long Axis)", minimum=0, maximum=2, step=0.01, value=1)
resize_init_blob_slider_short_axis = gr.Slider(label="Resize Initial Blob (Short Axis)", minimum=0, maximum=2, step=0.01, value=1)
gr.Markdown("### 🎨 Click to set the initial blob and upload object image for compositional generation👇", show_label=False)
with gr.Accordion("Compositional Generation", open=False, elem_id="accordion5"):
with gr.Row():
init_ellipse_parameter = gr.Textbox(label="Initial Ellipse", value="[0.5, 0.5, 0.2, 0.2, 180]", scale=4)
init_ellipse_button = gr.Button("Set Initial Ellipse", elem_id="set_init_ellipse_btn", scale=1)
with gr.Row(elem_id="Image"):
with gr.Tab(elem_classes="feedback1", label="User-specified Object Image"):
init_object_image = gr.Image(type="numpy", label="User-specified Object Image", height=320)
gr.Markdown("## **Step 4: 🚀 Run Generation**", show_label=False)
run_button = gr.Button("🚀 Run Generation", elem_id="btn")
with gr.Row():
with gr.Tab(elem_classes="feedback", label="Results"):
results_gallery = gr.Gallery(label='Results', height=320, elem_id="gallery", show_label=True, interactive=False, preview=True)
eg_index = gr.Textbox(label="Example Index", value="", visible=False)
with gr.Row():
examples_inputs = [
input_image,
scene_prompt,
blobnet_control_strength,
blobnet_control_guidance_start,
blobnet_control_guidance_end,
seed,
eg_index,
]
examples_outputs = [
object_image_gallery,
ori_result_gallery,
editable_blob,
edited_result_gallery,
results_gallery,
ellipse_lists,
tracking_points,
original_image,
remove_blob_box,
]
def process_example(input_image,
scene_prompt,
blobnet_control_strength,
blobnet_control_guidance_start,
blobnet_control_guidance_end,
seed,
eg_index):
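# Populate the UI from the precomputed assets of example `eg_index`, reloading
# every image from disk so cached objects are never mutated across runs.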
eg_index = int(eg_index)
# Force reload images from disk each time
object_image_gallery = [Image.open(path).copy() for path in OBJECT_IMAGE_GALLERY[eg_index]]
ori_result_gallery = [Image.open(path).copy() for path in ORI_RESULT_GALLERY[eg_index]]
editable_blob = Image.open(EDITABLE_BLOB[eg_index]).copy()
edited_result_gallery = [Image.open(path).copy() for path in EDITED_RESULT_GALLERY[eg_index]]
results_gallery = list(RESULTS_GALLERY[eg_index])  # paths only; the Gallery loads them from disk
# Deep copy mutable data structures
ellipse_lists = copy.deepcopy(ELLIPSE_LISTS[eg_index])
tracking_points = copy.deepcopy(TRACKING_POINTS[eg_index])
# Force reload input image
original_image = np.array(Image.open(INPUT_IMAGE[eg_index]).copy())
remove_blob_box = REMOVE_STATE[eg_index]
return object_image_gallery, ori_result_gallery, editable_blob, edited_result_gallery, results_gallery, ellipse_lists, tracking_points, original_image, remove_blob_box
example = gr.Examples(
label="Quick Example",
examples=EXAMPLES,
inputs=examples_inputs,
outputs=examples_outputs,
fn=process_example,
examples_per_page=10,
cache_examples=False,
run_on_click=True,
)
with gr.Row():
gr.Markdown(citation)
## initial
initial_output = [
input_image,
original_image,
editable_blob,
selected_points,
tracking_points,
ellipse_lists,
ori_result_gallery,
object_image_gallery,
edited_result_gallery,
results_gallery,
blobnet_control_strength,
blobnet_control_guidance_start,
blobnet_control_guidance_end,
resize_blob_slider_maintain_aspect_ratio,
resize_blob_slider_along_long_axis,
resize_blob_slider_along_short_axis,
rotation_blob_slider,
resize_init_blob_slider,
resize_init_blob_slider_long_axis,
resize_init_blob_slider_short_axis,
init_ellipse_parameter,
init_object_image,
remove_blob_box,
]
input_image.upload(
initialize_img,
[input_image],
initial_output
)
## select point
input_image.select(
get_point,
[original_image, selected_points],
[input_image, original_mask],
)
undo_seg_button.click(
undo_seg_points,
[original_image, selected_points],
[input_image, original_mask]
)
## blob image and tracking points: move
editable_blob.select(
add_tracking_points,
[original_image, tracking_points, ellipse_lists],
[editable_blob, tracking_points, ellipse_lists, edited_result_gallery]
)
## undo, reset and save blob
undo_blob_button.click(
undo_blob_points,
[original_image, tracking_points, ellipse_lists],
[editable_blob, tracking_points, ellipse_lists, edited_result_gallery, resize_blob_slider_maintain_aspect_ratio, resize_blob_slider_along_long_axis, resize_blob_slider_along_short_axis, rotation_blob_slider]
)
reset_blob_button.click(
reset_blob_points,
[original_image, tracking_points, ellipse_lists],
[editable_blob, tracking_points, ellipse_lists, edited_result_gallery, resize_blob_slider_maintain_aspect_ratio, resize_blob_slider_along_long_axis, resize_blob_slider_along_short_axis, rotation_blob_slider]
)
## generate blob
generate_blob_button.click(fn=generate_blob,
inputs=[original_image, original_mask, selected_points, ellipse_lists],
outputs=[editable_blob, ori_result_gallery, object_image_gallery, ellipse_lists, tracking_points, edited_result_gallery, resize_blob_slider_maintain_aspect_ratio, resize_blob_slider_along_long_axis, resize_blob_slider_along_short_axis, rotation_blob_slider, resize_init_blob_slider, resize_init_blob_slider_long_axis, resize_init_blob_slider_short_axis, init_ellipse_parameter, init_object_image])
## resize blob
resize_blob_slider_maintain_aspect_ratio.release(
resize_blob,
[editable_blob, original_image, tracking_points, ellipse_lists, resize_blob_slider_maintain_aspect_ratio, resize_blob_maintain_aspect_ratio_state, edited_result_gallery, remove_blob_box],
[editable_blob, ellipse_lists, edited_result_gallery, resize_blob_slider_maintain_aspect_ratio]
)
resize_blob_slider_along_long_axis.release(
resize_blob,
[editable_blob, original_image, tracking_points, ellipse_lists, resize_blob_slider_along_long_axis, resize_blob_along_long_axis_state, edited_result_gallery, remove_blob_box],
[editable_blob, ellipse_lists, edited_result_gallery, resize_blob_slider_along_long_axis]
)
resize_blob_slider_along_short_axis.release(
resize_blob,
[editable_blob, original_image, tracking_points, ellipse_lists, resize_blob_slider_along_short_axis, resize_blob_along_short_axis_state, edited_result_gallery, remove_blob_box],
[editable_blob, ellipse_lists, edited_result_gallery, resize_blob_slider_along_short_axis]
)
## rotate blob
rotation_blob_slider.release(
rotate_blob,
[editable_blob, original_image, tracking_points, ellipse_lists, rotation_blob_slider],
[editable_blob, ellipse_lists, edited_result_gallery, rotation_blob_slider]
)
remove_blob_box.change(
remove_blob_box_func,
[editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, remove_blob_box],
[editable_blob, ellipse_lists, edited_result_gallery, ori_result_gallery, resize_blob_slider_maintain_aspect_ratio]
)
## resize init blob
resize_init_blob_slider.release(
resize_start_blob,
[editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, resize_init_blob_slider, resize_blob_maintain_aspect_ratio_state],
[editable_blob, ellipse_lists, edited_result_gallery, ori_result_gallery, resize_init_blob_slider]
)
resize_init_blob_slider_long_axis.release(
resize_start_blob,
[editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, resize_init_blob_slider_long_axis, resize_blob_along_long_axis_state],
[editable_blob, ellipse_lists, edited_result_gallery, ori_result_gallery, resize_init_blob_slider_long_axis]
)
resize_init_blob_slider_short_axis.release(
resize_start_blob,
[editable_blob, original_image, tracking_points, ellipse_lists, ori_result_gallery, resize_init_blob_slider_short_axis, resize_blob_along_short_axis_state],
[editable_blob, ellipse_lists, edited_result_gallery, ori_result_gallery, resize_init_blob_slider_short_axis]
)
## set initial ellipse
init_ellipse_button.click(
set_init_ellipse,
inputs=[original_image, original_mask, edited_result_gallery, ellipse_lists, tracking_points, editable_blob, ori_result_gallery, init_ellipse_parameter],
outputs=[editable_blob, edited_result_gallery, ellipse_lists, tracking_points, ori_result_gallery, init_ellipse_parameter]
)
## upload user-specified object image
init_object_image.upload(
upload_object_image,
[init_object_image, edited_result_gallery, remove_blob_box],
[object_image_gallery, remove_blob_box]
)
## run BlobEdit
ips = [
original_image,
scene_prompt,
ori_result_gallery,
object_image_gallery,
edited_result_gallery,
ellipse_lists,
blobnet_control_strength,
blobnet_control_guidance_start,
blobnet_control_guidance_end,
remove_blob_box,
num_samples,
seed,
guidance_scale,
num_inference_steps,
# for save
editable_blob,
resize_blob_slider_maintain_aspect_ratio,
resize_blob_slider_along_long_axis,
resize_blob_slider_along_short_axis,
rotation_blob_slider,
resize_init_blob_slider,
resize_init_blob_slider_long_axis,
resize_init_blob_slider_short_axis,
tracking_points,
]
run_button.click(
run_function,
ips,
[results_gallery]
)
## If you hit a localhost access error, try launching with an explicit host and port:
# demo.launch(server_name="0.0.0.0", server_port=12346)
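## Optional: enable Gradio's request queue before launch if the demo will serve
## concurrent users (a sketch; tune max_size to the GPU budget)
# demo.queue(max_size=20)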
demo.launch()