import argparse
import base64
import json
import os

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image

from diffusers.utils import check_min_version
from pipeline_objectclear import ObjectClearPipeline
from tools.download_util import load_file_from_url
from tools.interact_tools import SamControler
from tools.misc import get_device
from tools.painter import mask_painter
from utils import resize_by_short_side

check_min_version("0.30.2")

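# Command-line options: compute device, SAM backbone, and Gradio port.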
def parse_augment():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--sam_model_type', type=str, default="vit_h")
    parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
    args = parser.parse_args()
    if not args.device:
        args.device = str(get_device())
    return args

# Convert accumulated click points into a SAM prompt dict.
def get_prompt(click_state, click_input):
    inputs = json.loads(click_input)
    points = click_state[0]
    labels = click_state[1]
    for point in inputs:
        points.append(point[:2])
        labels.append(point[2])
    click_state[0] = points
    click_state[1] = labels
    prompt = {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": True,
    }
    return prompt

# Use SAM to refine the mask from a user click on the image.
def sam_refine(image_state, point_prompt, click_state, evt: gr.SelectData):
    if point_prompt == "Positive":
        coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
    else:
        coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
    # Prompt for the SAM model
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(image_state["origin_image"])
    prompt = get_prompt(click_state=click_state, click_input=coordinate)
    mask, logit, painted_image = model.first_frame_click(
        image=image_state["origin_image"],
        points=np.array(prompt["input_point"]),
        labels=np.array(prompt["input_label"]),
        multimask=prompt["multimask_output"],
    )
    image_state["mask"] = mask
    image_state["logit"] = logit
    image_state["painted_image"] = painted_image
    return painted_image, image_state, click_state

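# Store the current SAM mask under a new name, refresh the dropdown, and repaint the preview.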
def add_multi_mask(image_state, interactive_state, mask_dropdown):
    mask = image_state["mask"]
    interactive_state["masks"].append(mask)
    interactive_state["mask_names"].append("mask_{:03d}".format(len(interactive_state["masks"])))
    mask_dropdown.append("mask_{:03d}".format(len(interactive_state["masks"])))
    select_frame = show_mask(image_state, interactive_state, mask_dropdown)
    return interactive_state, gr.update(choices=interactive_state["mask_names"], value=mask_dropdown), select_frame, [[], []]

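# Discard accumulated click points and restore the clean input image.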
def clear_click(image_state, click_state):
    click_state = [[], []]
    input_image = image_state["origin_image"]
    return input_image, click_state

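# Remove all saved masks and reset the mask dropdown and click state.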
def remove_multi_mask(interactive_state, click_state, image_state):
    interactive_state["mask_names"] = []
    interactive_state["masks"] = []
    click_state = [[], []]
    input_image = image_state["origin_image"]
    return interactive_state, gr.update(choices=[], value=[]), input_image, click_state

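# Paint every mask selected in the dropdown over the original image.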
def show_mask(image_state, interactive_state, mask_dropdown):
    mask_dropdown.sort()
    select_frame = image_state["origin_image"]
    if select_frame is None:
        return None
    for i in range(len(mask_dropdown)):
        mask_number = int(mask_dropdown[i].split("_")[1]) - 1
        mask = interactive_state["masks"][mask_number]
        select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number + 2)
    return select_frame

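# Reset all per-image state when a new image is uploaded.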
def upload_and_reset(image_input, interactive_state):
    click_state = [[], []]
    interactive_state["mask_names"] = []
    interactive_state["masks"] = []
    image_state, image_info, image_input = update_image_state_on_upload(image_input)
    return (
        image_state,
        image_info,
        image_input,
        interactive_state,
        click_state,
        gr.update(choices=[], value=[]),
    )

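# Build fresh image state from an uploaded PIL image and register it with SAM.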
def update_image_state_on_upload(image_input):
    frame = image_input
    # PIL's Image.size is (width, height); store (height, width) for numpy.
    image_size = (frame.size[1], frame.size[0])
    frame_np = np.array(frame)
    image_state = {
        "origin_image": frame_np,
        "painted_image": frame_np.copy(),
        "mask": np.zeros((image_size[0], image_size[1]), np.uint8),
        "logit": None,
    }
    image_info = f"Image Name: uploaded.png,\nImage Size: {image_size}"
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(frame_np)
    return image_state, image_info, image_input

# SAM generator
class MaskGenerator:
    def __init__(self, sam_checkpoint, args):
        self.args = args
        self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)

    def first_frame_click(self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True):
        mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
        return mask, logit, painted_image

# Runtime arguments and SAM checkpoint download
args = parse_augment()
sam_checkpoint_url_dict = {
    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}
checkpoint_folder = os.path.join('/home/user/app/', 'pretrained_models')
sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_folder)

# Initialize SAM
model = MaskGenerator(sam_checkpoint, args)

# Build the ObjectClear pipeline
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pipe = ObjectClearPipeline.from_pretrained_with_custom_modules(
    "jixin0101/ObjectClear",
    torch_dtype=torch.float16,
    variant='fp16',
    apply_attention_guided_fusion=True,
)
pipe.to(device)

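# Assumption: this Space runs on ZeroGPU (the "Running on Zero" banner), so
# inference is wrapped with the `spaces.GPU` decorator to request a GPU for
# each call; this also makes the otherwise-unused `spaces` import meaningful.
@spaces.GPU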
def process(image_state, interactive_state, mask_dropdown, guidance_scale, seed, num_inference_steps):
    generator = torch.Generator(device=device).manual_seed(seed)
    image_np = image_state["origin_image"]
    image = Image.fromarray(image_np)
    # Combine the selected masks into a single labeled template mask.
    if interactive_state["masks"]:
        if len(mask_dropdown) == 0:
            mask_dropdown = ["mask_001"]
        mask_dropdown.sort()
        template_mask = interactive_state["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * int(mask_dropdown[0].split("_")[1])
        for i in range(1, len(mask_dropdown)):
            mask_number = int(mask_dropdown[i].split("_")[1]) - 1
            template_mask = np.clip(template_mask + interactive_state["masks"][mask_number] * (mask_number + 1), 0, mask_number + 1)
        image_state["mask"] = template_mask
    else:
        template_mask = image_state["mask"]
    # Binarize so multi-mask labels (> 1) cannot overflow uint8 when scaled to 255.
    mask = Image.fromarray((template_mask > 0).astype(np.uint8) * 255)
    image_or = image.copy()
    image = image.convert("RGB")
    mask = mask.convert("RGB")
    # Match the training resolution: shorter side = 512 px.
    image = resize_by_short_side(image, 512, resample=Image.BICUBIC)
    mask = resize_by_short_side(mask, 512, resample=Image.NEAREST)
    w, h = image.size
    result = pipe(
        prompt="remove the instance of object",
        image=image,
        mask_image=mask,
        generator=generator,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        height=h,
        width=w,
    )
    fused_img_pil = result.images[0]
    # Restore the original resolution; Image.size is already a (w, h) tuple.
    return fused_img_pil.resize(image_or.size), (image.resize(image_or.size), fused_img_pil.resize(image_or.size))

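# Embed the logo as a base64 data URI so it renders without a static file route.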
with open("./Logo.png", "rb") as f:
    img_bytes = f.read()
img_b64 = base64.b64encode(img_bytes).decode()
html_img = f'''
<div style="display:flex; justify-content:center; align-items:center; width:100%;">
    <img src="data:image/png;base64,{img_b64}" style="border:none; width:200px; height:auto;"/>
</div>
'''

tutorial_url = "https://github.com/zjx0101/ObjectClear/releases/download/media/tutorial.mp4"
assets_path = os.path.join('/home/user/app/hugging_face/', "assets/")
load_file_from_url(tutorial_url, assets_path)

| description = r""" | |
| <b>Official Gradio demo</b> for <a href='https://github.com/zjx0101/ObjectClear' target='_blank'><b>ObjectClear: Complete Object Removal via Object-Effect Attention</b></a>.<br> | |
| 🔥 ObjectClear is an object removal model that can jointly eliminate the target object and its associated effects leveraging Object-Effect Attention, while preserving background consistency.<br> | |
| 🖼️ Try to drop your image, assign the target masks with a few clicks, and get the object removal results!<br> | |
| *Note: All input images are temporarily resized (shorter side = 512 pixels) during inference to match the training resolution. Final outputs are restored to the original resolution.<br>* | |
| """ | |
| article = r"""<h3> | |
| <b>If ObjectClear is helpful, please help to star the <a href='https://github.com/zjx0101/ObjectClear' target='_blank'>Github Repo</a>. Thanks!</b></h3> | |
| <hr> | |
| 📑 **Citation** | |
| <br> | |
| If our work is useful for your research, please consider citing: | |
| ```bibtex | |
| @InProceedings{zhao2025ObjectClear, | |
| title = {{ObjectClear}: Complete Object Removal via Object-Effect Attention}, | |
| author = {Zhao, Jixin and Zhou, Shangchen and Wang, Zhouxia and Yang, Peiqing and Loy, Chen Change}, | |
| booktitle = {arXiv preprint arXiv:2505.22636}, | |
| year = {2025} | |
| } | |
| ``` | |
| 📧 **Contact** | |
| <br> | |
| If you have any questions, please feel free to reach me out at <b>[email protected]</b>. | |
| <br> | |
| 👏 **Acknowledgement** | |
| <br> | |
| This demo is adapted from [MatAnyone](https://github.com/pq-yang/MatAnyone), and leveraging segmentation capabilities from [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks for their awesome works! | |
| """ | |
| custom_css = """ | |
| #input-image { | |
| aspect-ratio: 1 / 1; | |
| width: 100%; | |
| max-width: 100%; | |
| height: auto; | |
| display: flex; | |
| align-items: center; | |
| justify-content: center; | |
| } | |
| #input-image img { | |
| max-width: 100%; | |
| max-height: 100%; | |
| object-fit: contain; | |
| display: block; | |
| } | |
| #main-columns { | |
| gap: 60px; | |
| } | |
| #main-columns > .gr-column { | |
| flex: 1; | |
| } | |
| #compare-image { | |
| width: 100%; | |
| aspect-ratio: 1 / 1; | |
| display: flex; | |
| align-items: center; | |
| justify-content: center; | |
| margin: 0; | |
| padding: 0; | |
| max-width: 100%; | |
| box-sizing: border-box; | |
| } | |
| #compare-image svg.svelte-zyxd38 { | |
| position: absolute !important; | |
| top: 50% !important; | |
| left: 50% !important; | |
| transform: translate(-50%, -50%) !important; | |
| } | |
| #compare-image .icon.svelte-1oiin9d { | |
| position: absolute; | |
| top: 50%; | |
| left: 50%; | |
| transform: translate(-50%, -50%); | |
| } | |
| #compare-image { | |
| position: relative; | |
| overflow: hidden; | |
| } | |
| .new_button {background-color: #171717 !important; color: #ffffff !important; border: none !important;} | |
| .new_button:hover {background-color: #4b4b4b !important;} | |
| #start-button { | |
| background: linear-gradient(135deg, #2575fc 0%, #6a11cb 100%); | |
| color: white; | |
| border: none; | |
| padding: 12px 24px; | |
| font-size: 16px; | |
| font-weight: bold; | |
| border-radius: 12px; | |
| cursor: pointer; | |
| box-shadow: 0 0 12px rgba(100, 100, 255, 0.7); | |
| transition: all 0.3s ease; | |
| } | |
| #start-button:hover { | |
| transform: scale(1.05); | |
| box-shadow: 0 0 20px rgba(100, 100, 255, 1); | |
| } | |
| <style> | |
| .button-wrapper { | |
| width: 30%; | |
| text-align: center; | |
| } | |
| .wide-button { | |
| width: 83% !important; | |
| background-color: black !important; | |
| color: white !important; | |
| border: none !important; | |
| padding: 8px 0 !important; | |
| font-size: 16px !important; | |
| display: inline-block; | |
| margin: 30px 0px 0px 50px ; | |
| } | |
| .wide-button:hover { | |
| background-color: #656262 !important; | |
| } | |
| </style> | |
| """ | |
with gr.Blocks(css=custom_css) as demo:
    gr.HTML(html_img)
    gr.Markdown(description)
    with gr.Group(elem_classes="gr-monochrome-group", visible=True):
        with gr.Row():
            with gr.Accordion('SAM Settings (click to expand)', open=False):
                with gr.Row():
                    point_prompt = gr.Radio(
                        choices=["Positive", "Negative"],
                        value="Positive",
                        label="Point Prompt",
                        info="Click to add a positive or negative point for the target mask",
                        interactive=True,
                        min_width=100,
                        scale=1,
                    )
                    mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask Selection", info="Choose one or more of the masks added with 'Add Mask'")
| with gr.Row(elem_id="main-columns"): | |
| with gr.Column(): | |
| click_state = gr.State([[],[]]) | |
| interactive_state = gr.State( | |
| { | |
| "mask_names": [], | |
| "masks": [] | |
| } | |
| ) | |
| image_state = gr.State( | |
| { | |
| "origin_image": None, | |
| "painted_image": None, | |
| "mask": None, | |
| "logit": None | |
| } | |
| ) | |
| image_info = gr.Textbox(label="Image Info", visible=False) | |
| input_image = gr.Image( | |
| label='Input', | |
| type='pil', | |
| sources=["upload"], | |
| image_mode='RGB', | |
| interactive=True, | |
| elem_id="input-image" | |
| ) | |
            with gr.Row(equal_height=True, elem_classes="mask_button_group"):
                clear_button_click = gr.Button(value="Clear Clicks", elem_classes="new_button", min_width=100)
                add_mask_button = gr.Button(value="Add Mask", elem_classes="new_button", min_width=100)
                remove_mask_button = gr.Button(value="Delete Mask", elem_classes="new_button", min_width=100)
            submit_button_component = gr.Button(
                value='Start ObjectClear', elem_id="start-button"
            )
            with gr.Accordion('ObjectClear Settings', open=True):
                guidance_scale = gr.Slider(
                    minimum=1, maximum=10, step=0.5, value=2.5,
                    label="Guidance Scale",
                    info="Higher = stronger removal; lower = better background preservation (default: 2.5)",
                )
                seed = gr.Slider(
                    minimum=0, maximum=1000000, step=1, value=300000,
                    label="Seed Value",
                    info="Different seeds can lead to noticeably different object removal results (default: 300000)",
                )
                num_inference_steps = gr.Slider(
                    minimum=1, maximum=40, step=1, value=20,
                    label="Num Inference Steps",
                    info="Higher values may improve quality but take longer (default: 20)",
                )
        with gr.Column():
            # Note: the output image reuses elem_id "input-image" so the same CSS sizing applies.
            output_image_component = gr.Image(
                type='pil', image_mode='RGB', label='Output', format="png", elem_id="input-image")
            output_compare_image_component = gr.ImageSlider(
                label="Comparison",
                type="pil",
                format='png',
                elem_id="compare-image",
            )
    input_image.upload(
        fn=upload_and_reset,
        inputs=[input_image, interactive_state],
        outputs=[
            image_state,
            image_info,
            input_image,
            interactive_state,
            click_state,
            mask_dropdown,
        ],
    )
    # Click on the image to get a mask via SAM
    input_image.select(
        fn=sam_refine,
        inputs=[image_state, point_prompt, click_state],
        outputs=[input_image, image_state, click_state],
    )
    # Add another mask
    add_mask_button.click(
        fn=add_multi_mask,
        inputs=[image_state, interactive_state, mask_dropdown],
        outputs=[interactive_state, mask_dropdown, input_image, click_state],
    )
    remove_mask_button.click(
        fn=remove_multi_mask,
        inputs=[interactive_state, click_state, image_state],
        outputs=[interactive_state, mask_dropdown, input_image, click_state],
    )
    # Clear clicked points
    clear_button_click.click(
        fn=clear_click,
        inputs=[image_state, click_state],
        outputs=[input_image, click_state],
    )
    submit_button_component.click(
        fn=process,
        inputs=[
            image_state,
            interactive_state,
            mask_dropdown,
            guidance_scale,
            seed,
            num_inference_steps,
        ],
        outputs=[
            output_image_component, output_compare_image_component,
        ],
    )
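    # Video tutorial, downloaded earlier from the GitHub release assets.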
| with gr.Accordion("📕 Video Tutorial (click to expand)", open=False, elem_classes="custom-bg"): | |
| with gr.Row(): | |
| gr.Video(value="/home/user/app/hugging_face/assets/tutorial.mp4", elem_classes="video") | |
| gr.Markdown("---") | |
| gr.Markdown("## Examples") | |
| example_images = [ | |
| os.path.join(os.path.dirname(__file__), "examples", f"test{i}.png") | |
| for i in range(10) | |
| ] | |
| examples_data = [ | |
| [example_images[i], None] for i in range(len(example_images)) | |
| ] | |
    examples = gr.Examples(
        examples=examples_data,
        inputs=[input_image, interactive_state],
        outputs=[image_state, image_info, input_image,
                 interactive_state, click_state, mask_dropdown],
        fn=upload_and_reset,
        run_on_click=True,
        cache_examples=False,
        label="Click below to load example images",
    )
    gr.Markdown(article)

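    # Clear any stale input image when the page (re)loads.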
    def pre_update_input_image():
        return gr.update(value=None)

    demo.load(
        fn=pre_update_input_image,
        inputs=[],
        outputs=[input_image],
    )

demo.launch(debug=True, show_error=True)