Spaces: Runtime error
Ubuntu committed on
Commit e5efca7 · 1 Parent(s): c071a86
Update Inpainting Demo
Browse files
- .gitignore +1 -0
- .log/log.txt +6 -0
- SegFormer +1 -0
- output.png +0 -0
- requirements.txt +2 -2
- test.png +0 -0
- test.py +168 -76
.gitignore CHANGED
@@ -1,4 +1,5 @@
 __pycache__
 *.pyc
 checkpoints/
+I2SB/
 *.pth
.log/log.txt ADDED
@@ -0,0 +1,6 @@
+[19:02:29] INFO (0:00:00) Loaded options from opt_pkl_path=PosixPath('I2SB/results/inpaint-freeform2030/options.pkl')!
+INFO (0:00:00) [Diffusion] Built I2SB diffusion: steps=1000!
+[19:02:33] INFO (0:00:03) [Net] Initialized network from ckpt_pkl='I2SB/data/256x256_diffusion_uncond_fixedsigma.pkl'! Size=552807171!
+[19:02:44] INFO (0:00:14) [Net] Loaded pretrained adm ckpt_pt='I2SB/data/256x256_diffusion_uncond_fixedsigma.pt'!
+[19:02:49] INFO (0:00:19) [Net] Loaded network ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
+[19:02:50] INFO (0:00:20) [Ema] Loaded ema ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
SegFormer ADDED
@@ -0,0 +1 @@
+Subproject commit 64ab11278eb30b8e2d8ea1d10a777fc5b1563948
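The SegFormer entry is a gitlink (submodule pointer), so a fresh clone of the Space only records the commit hash above until the submodule is fetched with git submodule update --init SegFormer. A minimal startup guard is sketched below; the check itself is an assumption for illustration, not part of this commit.

# Minimal sketch: fail early if the SegFormer submodule has not been checked out.
from pathlib import Path

seg_dir = Path("SegFormer")
if not seg_dir.is_dir() or not any(seg_dir.iterdir()):
    raise RuntimeError("SegFormer submodule is empty; run: git submodule update --init SegFormer")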
output.png ADDED
requirements.txt CHANGED
@@ -18,8 +18,8 @@ timm
 # torch==2.0.0
 # torchvision==0.15.1
 
-torch==2.2.1
-torchvision==0.17.1
+# torch==2.2.1
+# torchvision==0.17.1
 
 gevent
 yapf
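With both pins now commented out, torch and torchvision come from whatever the Space's runtime image already provides. A quick sanity check is sketched below; running such a check at startup is an assumption, not part of this commit.

# Minimal sketch: confirm the environment-provided torch/torchvision versions and CUDA availability,
# since requirements.txt no longer pins them.
import torch
import torchvision

print("torch", torch.__version__, "torchvision", torchvision.__version__, "cuda", torch.cuda.is_available())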
test.png ADDED
test.py CHANGED
@@ -36,6 +36,34 @@ from GroundingDINO.groundingdino.util import box_ops
 from GroundingDINO.groundingdino.util.slconfig import SLConfig
 from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
 
+# I2SB
+import sys
+
+sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torchvision.transforms as transforms
+import torchvision.utils as tu
+from easydict import EasyDict as edict
+from fastapi import (Body, Depends, FastAPI, File, Form, HTTPException, Query,
+                     UploadFile)
+from ipdb import set_trace as debug
+from PIL import Image
+from torch.multiprocessing import Process
+from torch.utils.data import DataLoader, Subset
+from torch_ema import ExponentialMovingAverage
+
+import I2SB.distributed_util as dist_util
+from I2SB.corruption import build_corruption
+from I2SB.dataset import air_liquide
+from I2SB.i2sb import Runner, ckpt_util, download_ckpt
+from I2SB.logger import Logger
+from I2SB.sample import *
+
+
+
 import cv2
 import numpy as np
 import matplotlib
@@ -126,6 +154,30 @@ kosmos_processor = None
 colors = [(255, 0, 0), (0, 255, 0)]
 markers = [1, 5]
 
+i2sb_opt = edict(
+    distributed=False,
+    device="cuda",
+    batch_size=1,
+    nfe=10,
+    dataset="sample",
+    dataset_dir=Path(f"dataset/sample"),
+    n_gpu_per_node=1,
+    use_fp16=False,
+    ckpt="inpaint-freeform2030",
+    image_size=256,
+    partition=None,
+    global_size=1,
+    global_rank=0,
+    clip_denoise=True
+)
+
+i2sb_transforms = transforms.Compose([
+    transforms.Resize(i2sb_opt.image_size),
+    transforms.CenterCrop(i2sb_opt.image_size),
+    transforms.ToTensor(),
+    transforms.Lambda(lambda t: (t * 2) - 1)  # [0,1] --> [-1, 1]
+])
+
 def get_point(img, sel_pix, evt: gr.SelectData):
     img = np.array(img, dtype=np.uint8)
     sel_pix.append(evt.index)
@@ -146,6 +198,10 @@ def undo_button(orig_img, sel_pix):
     for point in sel_pix:
         cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
     return Image.fromarray(temp).convert("RGB")
+
+def clear_button(orig_img):
+
+    return orig_img, []
 
 def toggle_button(orig_img, task_type):
     print(task_type)
@@ -173,6 +229,37 @@ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
     _ = model.eval()
     return model
 
+def load_i2sb_model():
+    RESULT_DIR = Path("I2SB/results")
+    global i2sb_model
+    global ckpt_opt
+    global corrupt_type
+    global nfe
+
+    s = time.time()
+
+    # main from here
+    log = Logger(0, ".log")
+
+    # get (default) ckpt option
+    ckpt_opt = ckpt_util.build_ckpt_option(i2sb_opt, log, RESULT_DIR / i2sb_opt.ckpt)
+    corrupt_type = ckpt_opt.corrupt
+    nfe = i2sb_opt.nfe or ckpt_opt.interval-1
+
+    # build corruption method
+    # corrupt_method = build_corruption(i2sb_opt, log, corrupt_type=cor
+    # rupt_type)
+    runner = Runner(ckpt_opt, log, save_opt=False)
+    if i2sb_opt.use_fp16:
+        runner.ema.copy_to()  # copy weight from ema to net
+        runner.net.diffusion_model.convert_to_fp16()
+        runner.ema = ExponentialMovingAverage(
+            runner.net.parameters(), decay=0.99)  # re-init ema with fp16 weight
+
+    print("Loading time:", (time.time()-s)*1e3, "ms.")
+    i2sb_model = runner
+    return runner
+
 def plot_boxes_to_image(image_pil, tgt):
     H, W = tgt["size"]
     boxes = tgt["boxes"]
@@ -238,6 +325,8 @@ def load_image(image_path):
     image, _ = transform(image_pil, None)  # 3, h, w
     return image_pil, image
 
+
+
 def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
     caption = caption.lower()
     caption = caption.strip()
@@ -357,6 +446,24 @@ def load_sd_model(device):
         torch_dtype=torch.float16,
     )
     sd_model = sd_model.to(device)
+
+def forward_i2sb(img, mask):
+    print(np.unique(img), mask.shape)
+    mask = np.where(mask > 0, 1, 0)
+    img_tensor = i2sb_transforms(img).to(
+        i2sb_opt.device).unsqueeze(0)
+
+    mask_tensor = torch.from_numpy(np.resize(np.array(mask), (256,256))).to(
+        i2sb_opt.device).unsqueeze(0).unsqueeze(0)
+    print("POST PROCESSING\t", torch.unique(img_tensor))
+    # corrupt_tensor = img_tensor * (1. - mask_tensor) + mask_tensor
+    f = time.time()
+    xs, _ = i2sb_model.ddpm_sampling(
+        ckpt_opt, img_tensor, mask=mask_tensor, cond=None, clip_denoise=i2sb_opt.clip_denoise, nfe=nfe, verbose=i2sb_opt.n_gpu_per_node == 1)
+    recon_img = xs[:, 0, ...].to(i2sb_opt.device)
+    tu.save_image((recon_img+1)/2, "output.png")
+    print(recon_img.shape)
+    return transforms.ToPILImage()(((recon_img+1)/2)[0])
 
 def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
     try:
@@ -511,7 +618,7 @@ def concatenate_images_vertical(image1, image2):
     return new_image
 
 mask_source_draw = "draw a mask on input image"
-mask_source_segment = "
+mask_source_segment = "upload a mask"
 
 def get_time_cost(run_task_time, time_cost_str):
     now_time = int(time.time()*1000)
@@ -524,11 +631,8 @@ def get_time_cost(run_task_time, time_cost_str):
     run_task_time = now_time
     return run_task_time, time_cost_str
 
-def run_anything_task(input_image, input_points, origin_image,
-
-
-    text_prompt = getTextTrans(text_prompt, source='zh', target='en')
-    inpaint_prompt = getTextTrans(inpaint_prompt, source='zh', target='en')
+def run_anything_task(input_image, input_points, origin_image, task_type,
+                      mask_source_radio, cleaner_size_limit=1080):
 
     run_task_time = 0
     time_cost_str = ''
@@ -543,27 +647,19 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
         image_pil, image = load_image(input_image.convert("RGB"))
         input_img = input_image
 
-        kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil,
+        kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_model, kosmos_processor)
         run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
         return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
 
-    text_prompt = text_prompt.strip()
-    # if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
-    #     if text_prompt == '':
-    #         return [], gr.Gallery.update(label='Detection prompt is not found!ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
-
     if input_image is None:
         return [], gr.Gallery.update(label='Please upload a image!ππππ'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
 
     file_temp = int(time.time())
-    logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/
+    logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')
 
     output_images = []
 
     # load image
-    if mask_source_radio == mask_source_draw:
-        input_mask_pil = input_image['mask']
-        input_mask = np.array(input_mask_pil.convert("L"))
 
     if isinstance(input_image, dict):
         image_pil, image = load_image(input_image['image'].convert("RGB"))
@@ -626,17 +722,17 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
         logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
         return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
     elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
-        if
+        if mask_source_radio == mask_source_segment:
             task_type = 'remove'
 
         logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
         if mask_source_radio == mask_source_draw:
+            input_mask_pil = input_image['mask']
+            input_mask = np.array(input_mask_pil.convert("L"))
             mask_pil = input_mask_pil
             mask = input_mask
         else:
             masks_ori = copy.deepcopy(masks)
-            if inpaint_mode == 'merge':
-                masks = torch.sum(masks, dim=0).unsqueeze(0)
             masks = torch.where(masks > 0, True, False)
             mask = masks[0][0].cpu().numpy()
             mask_pil = Image.fromarray(mask)
@@ -644,18 +740,11 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
     if task_type in ['inpainting', 'outpainting']:
-        #
-
-
-
-
-        img_arr = np.array(image_mask_for_inpaint)
-        img_arr = np.where(img_arr > 0, 1, img_arr)
-        img_arr = 1 - img_arr
-        image_mask_for_inpaint = Image.fromarray(255*img_arr.astype('uint8'))
-        output_images.append(image_mask_for_inpaint.convert("RGB"))
-        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
-        image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
+        # image_inpainting = sd_model(prompt = "", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
+        input_img.save("test.png")
+        image_inpainting = forward_i2sb(input_img, mask)
+
+        print("RESULT\t", np.array(image_inpainting))
     else:
         # remove from mask
         aasds = 1
@@ -681,8 +770,6 @@ def run_anything_task(input_image, input_points, origin_image, text_prompt, task
     return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
 
 def change_radio_display(task_type, mask_source_radio, orig_img):
-    text_prompt_visible = True
-    inpaint_prompt_visible = False
     mask_source_radio_visible = False
     num_relation_visible = False
 
@@ -693,35 +780,29 @@ def change_radio_display(task_type, mask_source_radio, orig_img):
     print(task_type)
     if task_type == "Kosmos-2":
         if kosmos_enable:
-            text_prompt_visible = False
            image_gallery_visible = False
            kosmos_input_visible = True
            kosmos_output_visible = True
            kosmos_text_output_visible = True
 
-    if task_type in ['inpainting', 'outpainting']:
-        inpaint_prompt_visible = False
     if task_type in ['inpainting', 'outpainting'] or task_type == "remove":
         mask_source_radio_visible = True
-        if mask_source_radio == mask_source_draw:
-            text_prompt_visible = False
     if task_type == "relate anything":
-        text_prompt_visible = False
         num_relation_visible = True
     if task_type == "segment":
         ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
     elif task_type == "inpainting":
         ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
 
-    return (gr.
-            gr.Textbox.update(visible=inpaint_prompt_visible),
-            gr.Radio.update(visible=mask_source_radio_visible),
+    return (gr.Radio.update(visible=mask_source_radio_visible),
             gr.Slider.update(visible=num_relation_visible),
             gr.Gallery.update(visible=image_gallery_visible),
             gr.Radio.update(visible=kosmos_input_visible),
            gr.Image.update(visible=kosmos_output_visible),
            gr.HighlightedText.update(visible=kosmos_text_output_visible),
-            ret, [],
+            ret, [],
+            gr.Button("Undo point", visible = task_type == "segment"),
+            gr.Button("Clear point", visible = task_type == "segment"),)
 
 def get_model_device(module):
     try:
@@ -770,42 +851,52 @@ def main_gradio(args):
                 [input_image, selected_points],
                 [input_image]
             )
-
-
-
-
-
-
-
+            with gr.Row():
+                with gr.Column():
+
+                    undo_point_button = gr.Button("Undo point")
+                    undo_point_button.click(
+                        fn= undo_button,
+                        inputs=[original_image, selected_points],
+                        outputs=[input_image]
+                    )
+
+                with gr.Column():
+
+                    clear_point_button = gr.Button("Clear point")
+                    clear_point_button.click(
+                        fn= clear_button,
+                        inputs=[original_image],
+                        outputs=[input_image, selected_points]
+                    )
+
             print(dir(input_image))
             task_type = gr.Radio(task_types, value="segment",
                                  label='Task type', visible=True)
             mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
-                                         value=
+                                         value=mask_source_draw, label="Mask from",
                                          visible=False)
-            text_prompt = gr.Textbox(label="Detection", placeholder="Cannot be empty")
-            inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
             num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
 
             kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
 
             run_button = gr.Button(label="Run", visible=True)
-            with gr.Accordion("Advanced options", open=False) as advanced_options:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # with gr.Accordion("Advanced options", open=False) as advanced_options:
+            #     box_threshold = gr.Slider(
+            #         label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
+            #     )
+            #     text_threshold = gr.Slider(
+            #         label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+            #     )
+            #     iou_threshold = gr.Slider(
+            #         label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
+            #     )
+            #     inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
+            #     with gr.Row():
+            #         with gr.Column(scale=1):
+            #             remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
+            #         with gr.Column(scale=1):
+            #             remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
 
         with gr.Column():
             image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
@@ -841,15 +932,15 @@ def main_gradio(args):
         selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
 
         run_button.click(fn=run_anything_task, inputs=[
-            input_image, selected_points, original_image,
-
+            input_image, selected_points, original_image, task_type,
+            mask_source_radio],
             outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
 
         mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
-                                 outputs=[
+                                 outputs=[mask_source_radio, num_relation])
         task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
-                         outputs=[
-                         image_gallery, kosmos_input, kosmos_output, kosmos_text_output, input_image, selected_points, undo_point_button
+                         outputs=[mask_source_radio, num_relation,
+                         image_gallery, kosmos_input, kosmos_output, kosmos_text_output, input_image, selected_points, undo_point_button, clear_point_button
                          ])
 
         # DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
@@ -895,8 +986,9 @@ if __name__ == "__main__":
     if sam_enable:
         load_sam_model(device)
 
-
-
+    if inpainting_enable:
+        load_sd_model(device)
+        load_i2sb_model()
 
     # if lama_cleaner_enable:
     #     load_lama_cleaner_model(device)