Spaces:
Runtime error
Runtime error
update demo
Browse files- .gitignore +2 -1
- .log/log.txt +5 -5
- SegFormer +1 -1
- mask.png +0 -0
- output.png +0 -0
- streamlit_test.py +3 -0
- test.png +0 -0
- test.py +168 -242
.gitignore
CHANGED
|
@@ -2,4 +2,5 @@ __pycache__
|
|
| 2 |
*.pyc
|
| 3 |
checkpoints/
|
| 4 |
I2SB/
|
| 5 |
-
*.pth
|
|
|
|
|
|
| 2 |
*.pyc
|
| 3 |
checkpoints/
|
| 4 |
I2SB/
|
| 5 |
+
*.pth
|
| 6 |
+
SegFormer/
|
.log/log.txt
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
[19:
|
| 2 |
INFO (0:00:00) [Diffusion] Built I2SB diffusion: steps=1000!
|
| 3 |
-
[19:
|
| 4 |
-
[19:02
|
| 5 |
-
[19:
|
| 6 |
-
[19:
|
|
|
|
| 1 |
+
[19:58:55] INFO (0:00:00) Loaded options from opt_pkl_path=PosixPath('I2SB/results/inpaint-freeform2030/options.pkl')!
|
| 2 |
INFO (0:00:00) [Diffusion] Built I2SB diffusion: steps=1000!
|
| 3 |
+
[19:58:58] INFO (0:00:03) [Net] Initialized network from ckpt_pkl='I2SB/data/256x256_diffusion_uncond_fixedsigma.pkl'! Size=552807171!
|
| 4 |
+
[19:59:02] INFO (0:00:07) [Net] Loaded pretrained adm ckpt_pt='I2SB/data/256x256_diffusion_uncond_fixedsigma.pt'!
|
| 5 |
+
[19:59:06] INFO (0:00:11) [Net] Loaded network ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
|
| 6 |
+
[19:59:08] INFO (0:00:13) [Ema] Loaded ema ckpt: I2SB/results/inpaint-freeform2030/latest.pt!
|
SegFormer
CHANGED
|
@@ -1 +1 @@
|
|
| 1 |
-
Subproject commit
|
|
|
|
| 1 |
+
Subproject commit ccc3dd500c4091a583b4b2749e35da501e670aca
|
mask.png
ADDED
|
output.png
CHANGED
|
|
streamlit_test.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
st.write("Hello")
|
test.png
CHANGED
|
|
test.py
CHANGED
|
@@ -40,6 +40,7 @@ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases
|
|
| 40 |
import sys
|
| 41 |
|
| 42 |
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")
|
|
|
|
| 43 |
|
| 44 |
import numpy as np
|
| 45 |
import torch
|
|
@@ -62,6 +63,18 @@ from I2SB.i2sb import Runner, ckpt_util, download_ckpt
|
|
| 62 |
from I2SB.logger import Logger
|
| 63 |
from I2SB.sample import *
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
import cv2
|
|
@@ -89,13 +102,6 @@ if os.environ.get('IS_MY_DEBUG') is not None:
|
|
| 89 |
inpainting_enable = False
|
| 90 |
kosmos_enable = False
|
| 91 |
|
| 92 |
-
if lama_cleaner_enable:
|
| 93 |
-
try:
|
| 94 |
-
from lama_cleaner.model_manager import ModelManager
|
| 95 |
-
from lama_cleaner.schema import Config as lama_Config
|
| 96 |
-
except Exception as e:
|
| 97 |
-
lama_cleaner_enable = False
|
| 98 |
-
|
| 99 |
# segment anything
|
| 100 |
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
|
| 101 |
|
|
@@ -191,13 +197,16 @@ def get_point(img, sel_pix, evt: gr.SelectData):
|
|
| 191 |
|
| 192 |
|
| 193 |
def undo_button(orig_img, sel_pix):
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
sel_pix
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
def clear_button(orig_img):
|
| 203 |
|
|
@@ -256,10 +265,22 @@ def load_i2sb_model():
|
|
| 256 |
runner.ema = ExponentialMovingAverage(
|
| 257 |
runner.net.parameters(), decay=0.99) # re-init ema with fp16 weight
|
| 258 |
|
|
|
|
| 259 |
print("Loading time:", (time.time()-s)*1e3, "ms.")
|
| 260 |
i2sb_model = runner
|
| 261 |
return runner
|
| 262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
def plot_boxes_to_image(image_pil, tgt):
|
| 264 |
H, W = tgt["size"]
|
| 265 |
boxes = tgt["boxes"]
|
|
@@ -326,42 +347,6 @@ def load_image(image_path):
|
|
| 326 |
return image_pil, image
|
| 327 |
|
| 328 |
|
| 329 |
-
|
| 330 |
-
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
|
| 331 |
-
caption = caption.lower()
|
| 332 |
-
caption = caption.strip()
|
| 333 |
-
if not caption.endswith("."):
|
| 334 |
-
caption = caption + "."
|
| 335 |
-
model = model.to(device)
|
| 336 |
-
image = image.to(device)
|
| 337 |
-
with torch.no_grad():
|
| 338 |
-
outputs = model(image[None], captions=[caption])
|
| 339 |
-
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
|
| 340 |
-
boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
|
| 341 |
-
logits.shape[0]
|
| 342 |
-
|
| 343 |
-
# filter output
|
| 344 |
-
logits_filt = logits.clone()
|
| 345 |
-
boxes_filt = boxes.clone()
|
| 346 |
-
filt_mask = logits_filt.max(dim=1)[0] > box_threshold
|
| 347 |
-
logits_filt = logits_filt[filt_mask] # num_filt, 256
|
| 348 |
-
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
|
| 349 |
-
logits_filt.shape[0]
|
| 350 |
-
|
| 351 |
-
# get phrase
|
| 352 |
-
tokenlizer = model.tokenizer
|
| 353 |
-
tokenized = tokenlizer(caption)
|
| 354 |
-
# build pred
|
| 355 |
-
pred_phrases = []
|
| 356 |
-
for logit, box in zip(logits_filt, boxes_filt):
|
| 357 |
-
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
|
| 358 |
-
if with_logits:
|
| 359 |
-
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
|
| 360 |
-
else:
|
| 361 |
-
pred_phrases.append(pred_phrase)
|
| 362 |
-
|
| 363 |
-
return boxes_filt, pred_phrases
|
| 364 |
-
|
| 365 |
def show_mask(mask, ax, random_color=False):
|
| 366 |
if random_color:
|
| 367 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
|
@@ -447,99 +432,45 @@ def load_sd_model(device):
|
|
| 447 |
)
|
| 448 |
sd_model = sd_model.to(device)
|
| 449 |
|
| 450 |
-
def forward_i2sb(img, mask):
|
| 451 |
-
|
|
|
|
|
|
|
| 452 |
mask = np.where(mask > 0, 1, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
img_tensor = i2sb_transforms(img).to(
|
| 454 |
i2sb_opt.device).unsqueeze(0)
|
| 455 |
|
| 456 |
mask_tensor = torch.from_numpy(np.resize(np.array(mask), (256,256))).to(
|
| 457 |
i2sb_opt.device).unsqueeze(0).unsqueeze(0)
|
| 458 |
-
print("POST PROCESSING\t", torch.unique(img_tensor))
|
| 459 |
-
|
|
|
|
|
|
|
|
|
|
| 460 |
f = time.time()
|
| 461 |
xs, _ = i2sb_model.ddpm_sampling(
|
| 462 |
ckpt_opt, img_tensor, mask=mask_tensor, cond=None, clip_denoise=i2sb_opt.clip_denoise, nfe=nfe, verbose=i2sb_opt.n_gpu_per_node == 1)
|
| 463 |
recon_img = xs[:, 0, ...].to(i2sb_opt.device)
|
| 464 |
-
tu.save_image((recon_img+1)/2, "output.png")
|
|
|
|
| 465 |
print(recon_img.shape)
|
| 466 |
-
return transforms.ToPILImage()(((recon_img+1)/2)[0])
|
| 467 |
|
| 468 |
-
def
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
ori_image = image
|
| 472 |
-
if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
|
| 473 |
-
# rotate image
|
| 474 |
-
logger.info(f'_______lama_cleaner_process_______2____')
|
| 475 |
-
ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
|
| 476 |
-
logger.info(f'_______lama_cleaner_process_______3____')
|
| 477 |
-
image = ori_image
|
| 478 |
-
|
| 479 |
-
logger.info(f'_______lama_cleaner_process_______4____')
|
| 480 |
-
original_shape = ori_image.shape
|
| 481 |
-
logger.info(f'_______lama_cleaner_process_______5____')
|
| 482 |
-
interpolation = cv2.INTER_CUBIC
|
| 483 |
-
|
| 484 |
-
size_limit = cleaner_size_limit
|
| 485 |
-
if size_limit == -1:
|
| 486 |
-
logger.info(f'_______lama_cleaner_process_______6____')
|
| 487 |
-
size_limit = max(image.shape)
|
| 488 |
-
else:
|
| 489 |
-
logger.info(f'_______lama_cleaner_process_______7____')
|
| 490 |
-
size_limit = int(size_limit)
|
| 491 |
-
|
| 492 |
-
logger.info(f'_______lama_cleaner_process_______8____')
|
| 493 |
-
config = lama_Config(
|
| 494 |
-
ldm_steps=25,
|
| 495 |
-
ldm_sampler='plms',
|
| 496 |
-
zits_wireframe=True,
|
| 497 |
-
hd_strategy='Original',
|
| 498 |
-
hd_strategy_crop_margin=196,
|
| 499 |
-
hd_strategy_crop_trigger_size=1280,
|
| 500 |
-
hd_strategy_resize_limit=2048,
|
| 501 |
-
prompt='',
|
| 502 |
-
use_croper=False,
|
| 503 |
-
croper_x=0,
|
| 504 |
-
croper_y=0,
|
| 505 |
-
croper_height=512,
|
| 506 |
-
croper_width=512,
|
| 507 |
-
sd_mask_blur=5,
|
| 508 |
-
sd_strength=0.75,
|
| 509 |
-
sd_steps=50,
|
| 510 |
-
sd_guidance_scale=7.5,
|
| 511 |
-
sd_sampler='ddim',
|
| 512 |
-
sd_seed=42,
|
| 513 |
-
cv2_flag='INPAINT_NS',
|
| 514 |
-
cv2_radius=5,
|
| 515 |
-
)
|
| 516 |
-
|
| 517 |
-
logger.info(f'_______lama_cleaner_process_______9____')
|
| 518 |
-
if config.sd_seed == -1:
|
| 519 |
-
config.sd_seed = random.randint(1, 999999999)
|
| 520 |
-
|
| 521 |
-
# logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
|
| 522 |
-
logger.info(f'_______lama_cleaner_process_______10____')
|
| 523 |
-
image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
|
| 524 |
-
# logger.info(f"Resized image shape_1_: {image.shape}")
|
| 525 |
-
|
| 526 |
-
# logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
|
| 527 |
-
logger.info(f'_______lama_cleaner_process_______11____')
|
| 528 |
-
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
| 529 |
-
# logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
|
| 530 |
-
|
| 531 |
-
logger.info(f'_______lama_cleaner_process_______12____')
|
| 532 |
-
res_np_img = lama_cleaner_model(image, mask, config)
|
| 533 |
-
logger.info(f'_______lama_cleaner_process_______13____')
|
| 534 |
-
torch.cuda.empty_cache()
|
| 535 |
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
except Exception as e:
|
| 540 |
-
logger.info(f'lama_cleaner_process[Error]:' + str(e))
|
| 541 |
-
image = None
|
| 542 |
-
return image
|
| 543 |
|
| 544 |
# visualization
|
| 545 |
def draw_selected_mask(mask, draw):
|
|
@@ -632,27 +563,15 @@ def get_time_cost(run_task_time, time_cost_str):
|
|
| 632 |
return run_task_time, time_cost_str
|
| 633 |
|
| 634 |
def run_anything_task(input_image, input_points, origin_image, task_type,
|
| 635 |
-
mask_source_radio,
|
| 636 |
|
| 637 |
run_task_time = 0
|
| 638 |
time_cost_str = ''
|
| 639 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 640 |
print("HERE................", task_type)
|
| 641 |
-
|
| 642 |
-
global kosmos_model, kosmos_processor
|
| 643 |
-
if isinstance(input_image, dict):
|
| 644 |
-
image_pil, image = load_image(input_image['image'].convert("RGB"))
|
| 645 |
-
input_img = input_image['image']
|
| 646 |
-
else:
|
| 647 |
-
image_pil, image = load_image(input_image.convert("RGB"))
|
| 648 |
-
input_img = input_image
|
| 649 |
-
|
| 650 |
-
kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_model, kosmos_processor)
|
| 651 |
-
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 652 |
-
return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
|
| 653 |
-
|
| 654 |
if input_image is None:
|
| 655 |
-
return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
| 656 |
|
| 657 |
file_temp = int(time.time())
|
| 658 |
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')
|
|
@@ -682,92 +601,119 @@ def run_anything_task(input_image, input_points, origin_image, task_type,
|
|
| 682 |
groundingdino_device = 'cpu'
|
| 683 |
|
| 684 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
| 685 |
-
if task_type == 'segment' or
|
| 686 |
-
image = np.array(
|
| 687 |
-
if
|
| 688 |
-
sam_predictor
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 700 |
else:
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
plt.imshow(origin_image)
|
| 706 |
-
for mask in masks:
|
| 707 |
-
show_mask(mask, plt.gca(), random_color=True)
|
| 708 |
-
# for box, label in zip(boxes_filt, pred_phrases):
|
| 709 |
-
# show_box(box.cpu().numpy(), plt.gca(), label)
|
| 710 |
-
plt.axis('off')
|
| 711 |
-
image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
|
| 712 |
-
plt.savefig(image_path, bbox_inches="tight")
|
| 713 |
-
plt.clf()
|
| 714 |
-
plt.close('all')
|
| 715 |
-
segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
| 716 |
-
os.remove(image_path)
|
| 717 |
output_images.append(Image.fromarray(segment_image_result))
|
| 718 |
-
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
|
|
|
| 719 |
|
| 720 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
|
| 721 |
if task_type == 'detection' or task_type == 'segment':
|
| 722 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
| 723 |
-
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
| 724 |
-
elif task_type in ['inpainting', 'outpainting'] or task_type == '
|
| 725 |
-
if mask_source_radio == mask_source_segment:
|
| 726 |
-
task_type = 'remove'
|
| 727 |
|
| 728 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
|
| 729 |
-
if
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 734 |
else:
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 739 |
output_images.append(mask_pil.convert("RGB"))
|
| 740 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 741 |
|
| 742 |
-
if task_type in ['inpainting', '
|
| 743 |
# image_inpainting = sd_model(prompt = "", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
| 744 |
-
input_img.save("test.png")
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 748 |
else:
|
| 749 |
# remove from mask
|
| 750 |
aasds = 1
|
| 751 |
|
| 752 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
| 753 |
-
image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
|
| 754 |
if image_inpainting is None:
|
| 755 |
logger.info(f'run_anything_task_failed_')
|
| 756 |
-
return None, None, None, None
|
| 757 |
|
| 758 |
# output_images.append(image_inpainting)
|
| 759 |
# run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 760 |
|
| 761 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
|
| 762 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
|
|
|
| 763 |
output_images.append(image_inpainting)
|
| 764 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 765 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
| 766 |
-
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
| 767 |
else:
|
| 768 |
logger.info(f"task_type:{task_type} error!")
|
| 769 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
| 770 |
-
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
| 771 |
|
| 772 |
def change_radio_display(task_type, mask_source_radio, orig_img):
|
| 773 |
mask_source_radio_visible = False
|
|
@@ -789,20 +735,19 @@ def change_radio_display(task_type, mask_source_radio, orig_img):
|
|
| 789 |
mask_source_radio_visible = True
|
| 790 |
if task_type == "relate anything":
|
| 791 |
num_relation_visible = True
|
| 792 |
-
if task_type == "
|
| 793 |
-
ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
| 794 |
-
elif task_type == "inpainting":
|
| 795 |
ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
|
|
|
|
|
|
| 796 |
|
| 797 |
return (gr.Radio.update(visible=mask_source_radio_visible),
|
| 798 |
gr.Slider.update(visible=num_relation_visible),
|
| 799 |
gr.Gallery.update(visible=image_gallery_visible),
|
| 800 |
-
gr.Radio
|
| 801 |
-
gr.
|
| 802 |
-
gr.HighlightedText.update(visible=kosmos_text_output_visible),
|
| 803 |
ret, [],
|
| 804 |
-
gr.Button("Undo point", visible = task_type
|
| 805 |
-
gr.Button("Clear point", visible = task_type
|
| 806 |
|
| 807 |
def get_model_device(module):
|
| 808 |
try:
|
|
@@ -832,10 +777,11 @@ def main_gradio(args):
|
|
| 832 |
with gr.Row():
|
| 833 |
with gr.Column():
|
| 834 |
selected_points = gr.State([])
|
| 835 |
-
original_image = gr.State()
|
| 836 |
task_types = ["segment"]
|
| 837 |
if inpainting_enable:
|
| 838 |
task_types.append("inpainting")
|
|
|
|
| 839 |
|
| 840 |
|
| 841 |
input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512)
|
|
@@ -854,7 +800,7 @@ def main_gradio(args):
|
|
| 854 |
with gr.Row():
|
| 855 |
with gr.Column():
|
| 856 |
|
| 857 |
-
undo_point_button = gr.Button("Undo point")
|
| 858 |
undo_point_button.click(
|
| 859 |
fn= undo_button,
|
| 860 |
inputs=[original_image, selected_points],
|
|
@@ -863,7 +809,7 @@ def main_gradio(args):
|
|
| 863 |
|
| 864 |
with gr.Column():
|
| 865 |
|
| 866 |
-
clear_point_button = gr.Button("Clear point")
|
| 867 |
clear_point_button.click(
|
| 868 |
fn= clear_button,
|
| 869 |
inputs=[original_image],
|
|
@@ -876,10 +822,15 @@ def main_gradio(args):
|
|
| 876 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
| 877 |
value=mask_source_draw, label="Mask from",
|
| 878 |
visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 879 |
num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
|
| 880 |
|
| 881 |
-
kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
|
| 882 |
-
|
| 883 |
run_button = gr.Button(label="Run", visible=True)
|
| 884 |
# with gr.Accordion("Advanced options", open=False) as advanced_options:
|
| 885 |
# box_threshold = gr.Slider(
|
|
@@ -900,47 +851,21 @@ def main_gradio(args):
|
|
| 900 |
|
| 901 |
with gr.Column():
|
| 902 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
|
| 903 |
-
).style(preview=True, columns=[5], object_fit="scale-down", height=
|
| 904 |
time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
|
| 905 |
|
| 906 |
-
kosmos_output = gr.Image(type="pil", label="result images", visible=False)
|
| 907 |
-
kosmos_text_output = gr.HighlightedText(
|
| 908 |
-
label="Generated Description",
|
| 909 |
-
combine_adjacent=False,
|
| 910 |
-
show_legend=True,
|
| 911 |
-
visible=False,
|
| 912 |
-
).style(color_map=color_map)
|
| 913 |
-
# record which text span (label) is selected
|
| 914 |
-
selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
|
| 915 |
-
|
| 916 |
-
# record the current `entities`
|
| 917 |
-
entity_output = gr.Textbox(visible=False)
|
| 918 |
-
|
| 919 |
-
# get the current selected span label
|
| 920 |
-
def get_text_span_label(evt: gr.SelectData):
|
| 921 |
-
if evt.value[-1] is None:
|
| 922 |
-
return -1
|
| 923 |
-
return int(evt.value[-1])
|
| 924 |
-
# and set this information to `selected`
|
| 925 |
-
kosmos_text_output.select(get_text_span_label, None, selected)
|
| 926 |
|
| 927 |
-
# update output image when we change the span (enity) selection
|
| 928 |
-
def update_output_image(img_input, image_output, entities, idx):
|
| 929 |
-
entities = ast.literal_eval(entities)
|
| 930 |
-
updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
|
| 931 |
-
return updated_image
|
| 932 |
-
selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
|
| 933 |
|
| 934 |
run_button.click(fn=run_anything_task, inputs=[
|
| 935 |
input_image, selected_points, original_image, task_type,
|
| 936 |
-
mask_source_radio],
|
| 937 |
-
outputs=[image_gallery, image_gallery, time_cost, time_cost
|
| 938 |
|
| 939 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
| 940 |
outputs=[mask_source_radio, num_relation])
|
| 941 |
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
| 942 |
outputs=[mask_source_radio, num_relation,
|
| 943 |
-
image_gallery,
|
| 944 |
])
|
| 945 |
|
| 946 |
# DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
|
|
@@ -985,6 +910,7 @@ if __name__ == "__main__":
|
|
| 985 |
|
| 986 |
if sam_enable:
|
| 987 |
load_sam_model(device)
|
|
|
|
| 988 |
|
| 989 |
if inpainting_enable:
|
| 990 |
load_sd_model(device)
|
|
|
|
| 40 |
import sys
|
| 41 |
|
| 42 |
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/I2SB")
|
| 43 |
+
sys.path.insert(0, "/home/ubuntu/Thesis-Demo/SegFormer")
|
| 44 |
|
| 45 |
import numpy as np
|
| 46 |
import torch
|
|
|
|
| 63 |
from I2SB.logger import Logger
|
| 64 |
from I2SB.sample import *
|
| 65 |
|
| 66 |
+
from pathlib import Path
|
| 67 |
+
|
| 68 |
+
inpaint_checkpoint = Path("/home/ubuntu/Thesis-Demo/I2SB/results")
|
| 69 |
+
|
| 70 |
+
if not inpaint_checkpoint.exists():
|
| 71 |
+
os.system("pip install transformers==4.32.0")
|
| 72 |
+
|
| 73 |
+
# SegFormer
|
| 74 |
+
from PIL import Image
|
| 75 |
+
|
| 76 |
+
from SegFormer.mmseg.apis import inference_segmentor, init_segmentor, visualize_result_pyplot
|
| 77 |
+
from SegFormer.mmseg.core.evaluation import get_palette
|
| 78 |
|
| 79 |
|
| 80 |
import cv2
|
|
|
|
| 102 |
inpainting_enable = False
|
| 103 |
kosmos_enable = False
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
# segment anything
|
| 106 |
from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
|
| 107 |
|
|
|
|
| 197 |
|
| 198 |
|
| 199 |
def undo_button(orig_img, sel_pix):
|
| 200 |
+
if orig_img:
|
| 201 |
+
temp = orig_img.copy()
|
| 202 |
+
temp = np.array(temp, dtype=np.uint8)
|
| 203 |
+
if len(sel_pix) != 0:
|
| 204 |
+
sel_pix.pop()
|
| 205 |
+
for point in sel_pix:
|
| 206 |
+
cv2.drawMarker(temp, point, colors[0], markerType=markers[0], markerSize=6, thickness=2)
|
| 207 |
+
return Image.fromarray(temp).convert("RGB")
|
| 208 |
+
return orig_img
|
| 209 |
+
|
| 210 |
|
| 211 |
def clear_button(orig_img):
|
| 212 |
|
|
|
|
| 265 |
runner.ema = ExponentialMovingAverage(
|
| 266 |
runner.net.parameters(), decay=0.99) # re-init ema with fp16 weight
|
| 267 |
|
| 268 |
+
logger.info(f"I2SB Loading time:\t {(time.time()-s)*1e3} ms.")
|
| 269 |
print("Loading time:", (time.time()-s)*1e3, "ms.")
|
| 270 |
i2sb_model = runner
|
| 271 |
return runner
|
| 272 |
|
| 273 |
+
def load_segformer(device):
|
| 274 |
+
global segformer_model
|
| 275 |
+
s = time.time()
|
| 276 |
+
config = "SegFormer/local_configs/segformer/B3/segformer.b3.256x256.wtm.160k.py"
|
| 277 |
+
checkpoint = "SegFormer/work_dirs/segformer.b3.256x256.wtm.160k/iter_160000.pth"
|
| 278 |
+
model = init_segmentor(config, checkpoint, device=device)
|
| 279 |
+
|
| 280 |
+
logger.info(f"SegFormer Loading time:\t {(time.time()-s)*1e3} ms.")
|
| 281 |
+
segformer_model = model
|
| 282 |
+
return model
|
| 283 |
+
|
| 284 |
def plot_boxes_to_image(image_pil, tgt):
|
| 285 |
H, W = tgt["size"]
|
| 286 |
boxes = tgt["boxes"]
|
|
|
|
| 347 |
return image_pil, image
|
| 348 |
|
| 349 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
def show_mask(mask, ax, random_color=False):
|
| 351 |
if random_color:
|
| 352 |
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
|
|
|
| 432 |
)
|
| 433 |
sd_model = sd_model.to(device)
|
| 434 |
|
| 435 |
+
def forward_i2sb(img, mask, dilation_mask_extend):
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
print(np.unique(mask),mask.shape)
|
| 439 |
mask = np.where(mask > 0, 1, 0)
|
| 440 |
+
print(np.unique(mask),mask.shape)
|
| 441 |
+
mask = mask.astype(np.uint8)
|
| 442 |
+
if dilation_mask_extend.isdigit():
|
| 443 |
+
|
| 444 |
+
kernel_size = int(dilation_mask_extend)
|
| 445 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (int(kernel_size), int(kernel_size)))
|
| 446 |
+
mask = cv2.dilate(mask, kernel, iterations = 1)
|
| 447 |
+
|
| 448 |
img_tensor = i2sb_transforms(img).to(
|
| 449 |
i2sb_opt.device).unsqueeze(0)
|
| 450 |
|
| 451 |
mask_tensor = torch.from_numpy(np.resize(np.array(mask), (256,256))).to(
|
| 452 |
i2sb_opt.device).unsqueeze(0).unsqueeze(0)
|
| 453 |
+
# print("POST PROCESSING\t", torch.unique(img_tensor))
|
| 454 |
+
corrupt_tensor = img_tensor * (1. - mask_tensor) + mask_tensor
|
| 455 |
+
print("DOUBLE CHECK:\t", corrupt_tensor.shape)
|
| 456 |
+
print("DOUBLE CHECK:\t", img_tensor.shape)
|
| 457 |
+
print("DOUBLE CHECK:\t", mask_tensor.shape)
|
| 458 |
f = time.time()
|
| 459 |
xs, _ = i2sb_model.ddpm_sampling(
|
| 460 |
ckpt_opt, img_tensor, mask=mask_tensor, cond=None, clip_denoise=i2sb_opt.clip_denoise, nfe=nfe, verbose=i2sb_opt.n_gpu_per_node == 1)
|
| 461 |
recon_img = xs[:, 0, ...].to(i2sb_opt.device)
|
| 462 |
+
# tu.save_image((recon_img+1)/2, "output.png")
|
| 463 |
+
# tu.save_image((corrupt_tensor+1)/2, "output.png")
|
| 464 |
print(recon_img.shape)
|
| 465 |
+
return transforms.ToPILImage()(((recon_img+1)/2)[0]), transforms.ToPILImage()(((corrupt_tensor+1)/2)[0])
|
| 466 |
|
| 467 |
+
def forward_segformer(img):
|
| 468 |
+
img_np = np.array(img)
|
| 469 |
+
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 470 |
|
| 471 |
+
result = inference_segmentor(segformer_model, img_np)
|
| 472 |
+
|
| 473 |
+
return np.asarray(result[0], dtype=np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
|
| 475 |
# visualization
|
| 476 |
def draw_selected_mask(mask, draw):
|
|
|
|
| 563 |
return run_task_time, time_cost_str
|
| 564 |
|
| 565 |
def run_anything_task(input_image, input_points, origin_image, task_type,
|
| 566 |
+
mask_source_radio, segmentation_radio, dilation_mask_extend):
|
| 567 |
|
| 568 |
run_task_time = 0
|
| 569 |
time_cost_str = ''
|
| 570 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 571 |
print("HERE................", task_type)
|
| 572 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
if input_image is None:
|
| 574 |
+
return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
| 575 |
|
| 576 |
file_temp = int(time.time())
|
| 577 |
logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/[{mask_source_radio}]_1_')
|
|
|
|
| 601 |
groundingdino_device = 'cpu'
|
| 602 |
|
| 603 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
|
| 604 |
+
if task_type == 'segment' or task_type == 'pipeline':
|
| 605 |
+
image = np.array(origin_image)
|
| 606 |
+
if segmentation_radio == "SAM":
|
| 607 |
+
if sam_predictor:
|
| 608 |
+
sam_predictor.set_image(image)
|
| 609 |
+
|
| 610 |
+
if sam_predictor:
|
| 611 |
+
logger.info(f"Forward with: {input_points}")
|
| 612 |
+
masks, _, _, _ = sam_predictor.predict(
|
| 613 |
+
point_coords = np.array(input_points),
|
| 614 |
+
point_labels = np.array([1 for _ in range(len(input_points))]),
|
| 615 |
+
# boxes = transformed_boxes,
|
| 616 |
+
multimask_output = False,
|
| 617 |
+
)
|
| 618 |
+
# masks: [9, 1, 512, 512]
|
| 619 |
+
assert sam_checkpoint, 'sam_checkpoint is not found!'
|
| 620 |
+
else:
|
| 621 |
+
run_mode = "rectangle"
|
| 622 |
+
|
| 623 |
+
# draw output image
|
| 624 |
+
plt.figure(figsize=(10, 10))
|
| 625 |
+
plt.imshow(origin_image)
|
| 626 |
+
for mask in masks:
|
| 627 |
+
show_mask(mask, plt.gca(), random_color=True)
|
| 628 |
+
# for box, label in zip(boxes_filt, pred_phrases):
|
| 629 |
+
# show_box(box.cpu().numpy(), plt.gca(), label)
|
| 630 |
+
plt.axis('off')
|
| 631 |
+
image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
|
| 632 |
+
plt.savefig(image_path, bbox_inches="tight")
|
| 633 |
+
plt.clf()
|
| 634 |
+
plt.close('all')
|
| 635 |
+
segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
| 636 |
+
os.remove(image_path)
|
| 637 |
+
|
| 638 |
else:
|
| 639 |
+
masks = forward_segformer(image)
|
| 640 |
+
|
| 641 |
+
segment_image_result = visualize_result_pyplot(segformer_model, image, masks, get_palette("wtm"), dilation=dilation_mask_extend)# if task_type == "pipeline" else None)
|
| 642 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
output_images.append(Image.fromarray(segment_image_result))
|
| 644 |
+
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 645 |
+
|
| 646 |
|
| 647 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
|
| 648 |
if task_type == 'detection' or task_type == 'segment':
|
| 649 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
| 650 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
| 651 |
+
elif task_type in ['inpainting', 'outpainting'] or task_type == 'pipeline':
|
|
|
|
|
|
|
| 652 |
|
| 653 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
|
| 654 |
+
if task_type == "pipeline":
|
| 655 |
+
if segmentation_radio == "SAM":
|
| 656 |
+
masks_ori = copy.deepcopy(masks)
|
| 657 |
+
print(masks.shape)
|
| 658 |
+
# masks = torch.where(masks > 0, True, False)
|
| 659 |
+
mask = masks[0]
|
| 660 |
+
mask_pil = Image.fromarray(mask)
|
| 661 |
+
mask = np.where(mask == True, 1, 0)
|
| 662 |
+
else:
|
| 663 |
+
mask = masks
|
| 664 |
+
save_mask = copy.deepcopy(mask)
|
| 665 |
+
save_mask = np.where(mask > 0, 255, 0).astype(np.uint8)
|
| 666 |
+
print((save_mask.dtype))
|
| 667 |
+
mask_pil = Image.fromarray(save_mask)
|
| 668 |
+
|
| 669 |
else:
|
| 670 |
+
if mask_source_radio == mask_source_draw:
|
| 671 |
+
input_mask_pil = input_image['mask']
|
| 672 |
+
input_mask = np.array(input_mask_pil.convert("L"))
|
| 673 |
+
mask_pil = input_mask_pil
|
| 674 |
+
mask = input_mask
|
| 675 |
+
else:
|
| 676 |
+
pass
|
| 677 |
+
# masks_ori = copy.deepcopy(masks)
|
| 678 |
+
# masks = torch.where(masks > 0, True, False)
|
| 679 |
+
# mask = masks[0][0].cpu().numpy()
|
| 680 |
+
# mask_pil = Image.fromarray(mask)
|
| 681 |
output_images.append(mask_pil.convert("RGB"))
|
| 682 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 683 |
|
| 684 |
+
if task_type in ['inpainting', 'pipeline']:
|
| 685 |
# image_inpainting = sd_model(prompt = "", image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
| 686 |
+
# input_img.save("test.png")
|
| 687 |
+
w, h = input_img.size
|
| 688 |
+
input_img = input_img.resize((256,256))
|
| 689 |
+
image_inpainting, corrupted = forward_i2sb(input_img, mask, dilation_mask_extend)
|
| 690 |
+
input_img = input_img.resize((w,h))
|
| 691 |
+
corrupted = corrupted.resize((w,h))
|
| 692 |
+
image_inpainting = image_inpainting.resize((w,h))
|
| 693 |
+
# print("RESULT\t", np.array(image_inpainting))
|
| 694 |
else:
|
| 695 |
# remove from mask
|
| 696 |
aasds = 1
|
| 697 |
|
| 698 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
|
|
|
|
| 699 |
if image_inpainting is None:
|
| 700 |
logger.info(f'run_anything_task_failed_')
|
| 701 |
+
return None, None, None, None
|
| 702 |
|
| 703 |
# output_images.append(image_inpainting)
|
| 704 |
# run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 705 |
|
| 706 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
|
| 707 |
image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
|
| 708 |
+
output_images.append(corrupted)
|
| 709 |
output_images.append(image_inpainting)
|
| 710 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
| 711 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
|
| 712 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
| 713 |
else:
|
| 714 |
logger.info(f"task_type:{task_type} error!")
|
| 715 |
logger.info(f'run_anything_task_[{file_temp}]_9_9_')
|
| 716 |
+
return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !=''))
|
| 717 |
|
| 718 |
def change_radio_display(task_type, mask_source_radio, orig_img):
|
| 719 |
mask_source_radio_visible = False
|
|
|
|
| 735 |
mask_source_radio_visible = True
|
| 736 |
if task_type == "relate anything":
|
| 737 |
num_relation_visible = True
|
| 738 |
+
if task_type == "inpainting":
|
|
|
|
|
|
|
| 739 |
ret = gr.Image(value = orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
| 740 |
+
elif task_type in ["segment", "pipeline"]:
|
| 741 |
+
ret = gr.Image(value= orig_img, elem_id="image_upload", type='pil', label="Upload", height=512, tool = "editor")# tool = "sketch", brush_color='#00FFFF', mask_opacity=0.6)
|
| 742 |
|
| 743 |
return (gr.Radio.update(visible=mask_source_radio_visible),
|
| 744 |
gr.Slider.update(visible=num_relation_visible),
|
| 745 |
gr.Gallery.update(visible=image_gallery_visible),
|
| 746 |
+
gr.Radio(["SegFormer", "SAM"], value="SAM", label="Segementation Model", visible= task_type != "inpainting"),
|
| 747 |
+
gr.Textbox(label="Dilation kernel size", value='7', visible= task_type == "pipeline"),
|
|
|
|
| 748 |
ret, [],
|
| 749 |
+
gr.Button("Undo point", visible = task_type != "inpainting"),
|
| 750 |
+
gr.Button("Clear point", visible = task_type != "inpainting"),)
|
| 751 |
|
| 752 |
def get_model_device(module):
|
| 753 |
try:
|
|
|
|
| 777 |
with gr.Row():
|
| 778 |
with gr.Column():
|
| 779 |
selected_points = gr.State([])
|
| 780 |
+
original_image = gr.State(None)
|
| 781 |
task_types = ["segment"]
|
| 782 |
if inpainting_enable:
|
| 783 |
task_types.append("inpainting")
|
| 784 |
+
task_types.append("pipeline")
|
| 785 |
|
| 786 |
|
| 787 |
input_image = gr.Image(elem_id="image_upload", type='pil', label="Upload", height=512)
|
|
|
|
| 800 |
with gr.Row():
|
| 801 |
with gr.Column():
|
| 802 |
|
| 803 |
+
undo_point_button = gr.Button("Undo point", visible= True if original_image is not None else False)
|
| 804 |
undo_point_button.click(
|
| 805 |
fn= undo_button,
|
| 806 |
inputs=[original_image, selected_points],
|
|
|
|
| 809 |
|
| 810 |
with gr.Column():
|
| 811 |
|
| 812 |
+
clear_point_button = gr.Button("Clear point", visible= True if original_image is not None else False)
|
| 813 |
clear_point_button.click(
|
| 814 |
fn= clear_button,
|
| 815 |
inputs=[original_image],
|
|
|
|
| 822 |
mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
|
| 823 |
value=mask_source_draw, label="Mask from",
|
| 824 |
visible=False)
|
| 825 |
+
|
| 826 |
+
segmentation_radio = gr.Radio(["SegFormer", "SAM"],
|
| 827 |
+
value="SAM", label="Segementation Model",
|
| 828 |
+
visible=True)
|
| 829 |
+
|
| 830 |
+
dilation_mask_extend = gr.Textbox(label="Dilation kernel size", value='5', visible=False)
|
| 831 |
+
|
| 832 |
num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
|
| 833 |
|
|
|
|
|
|
|
| 834 |
run_button = gr.Button(label="Run", visible=True)
|
| 835 |
# with gr.Accordion("Advanced options", open=False) as advanced_options:
|
| 836 |
# box_threshold = gr.Slider(
|
|
|
|
| 851 |
|
| 852 |
with gr.Column():
|
| 853 |
image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
|
| 854 |
+
).style(preview=True, columns=[5], object_fit="scale-down", height=512)
|
| 855 |
time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
|
| 856 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 857 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 858 |
|
| 859 |
run_button.click(fn=run_anything_task, inputs=[
|
| 860 |
input_image, selected_points, original_image, task_type,
|
| 861 |
+
mask_source_radio, segmentation_radio, dilation_mask_extend],
|
| 862 |
+
outputs=[image_gallery, image_gallery, time_cost, time_cost], show_progress=True, queue=True)
|
| 863 |
|
| 864 |
mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
| 865 |
outputs=[mask_source_radio, num_relation])
|
| 866 |
task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio, original_image],
|
| 867 |
outputs=[mask_source_radio, num_relation,
|
| 868 |
+
image_gallery, segmentation_radio, dilation_mask_extend, input_image, selected_points, undo_point_button, clear_point_button
|
| 869 |
])
|
| 870 |
|
| 871 |
# DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
|
|
|
|
| 910 |
|
| 911 |
if sam_enable:
|
| 912 |
load_sam_model(device)
|
| 913 |
+
load_segformer(device)
|
| 914 |
|
| 915 |
if inpainting_enable:
|
| 916 |
load_sd_model(device)
|