Spaces: yizhangliu/Grounded-Segment-Anything · Runtime error
liuyizhang committed · Commit 7a7f9d8 · 1 Parent(s): 5d0da89

add ram

Files changed:
- app.py +200 -25
- assets/OpenSans-Bold.ttf +0 -0
- checkpoints/ram_epoch12.pth +3 -0
- ram_train_eval.py +416 -0
- ram_utils.py +152 -0
- requirements.txt +2 -7
app.py
CHANGED
@@ -44,7 +44,7 @@ from lama_cleaner.model_manager import ModelManager
 from lama_cleaner.schema import Config
 
 # segment anything
-from segment_anything import build_sam, SamPredictor
+from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
 
 # diffusers
 import PIL

@@ -238,6 +238,7 @@ groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
 # initialize SAM
 logger.info(f"initialize SAM model...")
 sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
+sam_mask_generator = SamAutomaticMaskGenerator(sam_predictor)
 
 # initialize stable-diffusion-inpainting
 logger.info(f"initialize stable-diffusion-inpainting...")

@@ -319,11 +320,168 @@ def lama_cleaner_process(image, mask):
     image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
     return image
 
+# relate anything
+from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, show_mask
+from ram_train_eval import RamModel,RamPredictor
+from mmengine.config import Config
+input_size = 512
+hidden_size = 256
+num_classes = 56
+
+# load ram model
+model_path = "./checkpoints/ram_epoch12.pth"
+config = dict(
+    model=dict(
+        pretrained_model_name_or_path='bert-base-uncased',
+        load_pretrained_weights=False,
+        num_transformer_layer=2,
+        input_feature_size=256,
+        output_feature_size=768,
+        cls_feature_size=512,
+        num_relation_classes=56,
+        pred_type='attention',
+        loss_type='multi_label_ce',
+    ),
+    load_from=model_path,
+)
+config = Config(config)
+
+class Predictor(RamPredictor):
+    def __init__(self, config, device='cpu'):
+        self.config = config
+        self.device = torch.device(device)
+        self._build_model()
+
+    def _build_model(self):
+        self.model = RamModel(**self.config.model).to(self.device)
+        if self.config.load_from is not None:
+            self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
+        self.model.train()
+ram_model = Predictor(config, device)
+
+# visualization
+def draw_selected_mask(mask, draw):
+    color = (255, 0, 0, 153)
+    nonzero_coords = np.transpose(np.nonzero(mask))
+    for coord in nonzero_coords:
+        draw.point(coord[::-1], fill=color)
+
+def draw_object_mask(mask, draw):
+    color = (0, 0, 255, 153)
+    nonzero_coords = np.transpose(np.nonzero(mask))
+    for coord in nonzero_coords:
+        draw.point(coord[::-1], fill=color)
+
+def create_title_image(word1, word2, word3, width, font_path='./assets/OpenSans-Bold.ttf'):
+    # Define the colors to use for each word
+    color_red = (255, 0, 0)
+    color_black = (0, 0, 0)
+    color_blue = (0, 0, 255)
+
+    # Define the initial font size and spacing between words
+    font_size = 40
+
+    # Create a new image with the specified width and white background
+    image = Image.new('RGB', (width, 60), (255, 255, 255))
+
+    # Load the specified font
+    font = ImageFont.truetype(font_path, font_size)
+
+    # Keep increasing the font size until all words fit within the desired width
+    while True:
+        # Create a draw object for the image
+        draw = ImageDraw.Draw(image)
+
+        word_spacing = font_size / 2
+        # Draw each word in the appropriate color
+        x_offset = word_spacing
+        draw.text((x_offset, 0), word1, color_red, font=font)
+        x_offset += font.getsize(word1)[0] + word_spacing
+        draw.text((x_offset, 0), word2, color_black, font=font)
+        x_offset += font.getsize(word2)[0] + word_spacing
+        draw.text((x_offset, 0), word3, color_blue, font=font)
+
+        word_sizes = [font.getsize(word) for word in [word1, word2, word3]]
+        total_width = sum([size[0] for size in word_sizes]) + word_spacing * 3
+
+        # Stop increasing font size if the image is within the desired width
+        if total_width <= width:
+            break
+
+        # Increase font size and reset the draw object
+        font_size -= 1
+        image = Image.new('RGB', (width, 50), (255, 255, 255))
+        font = ImageFont.truetype(font_path, font_size)
+        draw = None
+
+    return image
+
+def concatenate_images_vertical(image1, image2):
+    # Get the dimensions of the two images
+    width1, height1 = image1.size
+    width2, height2 = image2.size
+
+    # Create a new image with the combined height and the maximum width
+    new_image = Image.new('RGBA', (max(width1, width2), height1 + height2))
+
+    # Paste the first image at the top of the new image
+    new_image.paste(image1, (0, 0))
+
+    # Paste the second image below the first image
+    new_image.paste(image2, (0, height1))
+
+    return new_image
+
+def relate_anything(input_image, k):
+    w, h = input_image.size
+    max_edge = 1500
+    if w > max_edge or h > max_edge:
+        ratio = max(w, h) / max_edge
+        new_size = (int(w / ratio), int(h / ratio))
+        input_image.thumbnail(new_size)
+
+    # load image
+    pil_image = input_image.convert('RGBA')
+    image = np.array(input_image)
+    sam_masks = sam_mask_generator.generate(image)
+    filtered_masks = sort_and_deduplicate(sam_masks)
+
+    feat_list = []
+    for fm in filtered_masks:
+        feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
+        feat_list.append(feat)
+    feat = torch.cat(feat_list, dim=1).to(device)
+    matrix_output, rel_triplets = ram_model.predict(feat)
+
+    pil_image_list = []
+    for i, rel in enumerate(rel_triplets[:k]):
+        s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
+        relation = relation_classes[r]
+
+        mask_image = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
+        mask_draw = ImageDraw.Draw(mask_image)
+
+        draw_selected_mask(filtered_masks[s]['segmentation'], mask_draw)
+        draw_object_mask(filtered_masks[o]['segmentation'], mask_draw)
+
+        current_pil_image = pil_image.copy()
+        current_pil_image.alpha_composite(mask_image)
+
+        title_image = create_title_image('Red', relation, 'Blue', current_pil_image.size[0])
+        concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
+        pil_image_list.append(concate_pil_image)
+
+        yield pil_image_list
+
+
 mask_source_draw = "draw a mask on input image"
 mask_source_segment = "type what to detect below"
 
-def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
-                        iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend):
+def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
+                        iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation):
+    if task_type == "relate anything":
+        return relate_anything(input_image['image'], num_relation)
+
     text_prompt = text_prompt.strip()
     if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
         if text_prompt == '':

@@ -333,7 +491,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
         return [], gr.Gallery.update(label='Please upload a image!😂😂😂😂')
 
     file_temp = int(time.time())
-    logger.info(f'
+    logger.info(f'run_anything_task_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
 
     # load image
    input_mask_pil = input_image['mask']

@@ -364,7 +522,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
         groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
     )
     if boxes_filt.size(0) == 0:
-        logger.info(f'
+        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1_[No objects detected, please try others.]_')
         return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂')
     boxes_filt_ori = copy.deepcopy(boxes_filt)
 

@@ -380,7 +538,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
         os.remove(image_path)
         output_images.append(detection_image_result)
 
-    logger.info(f'
+    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
     if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
         image = np.array(input_image['image'])
         sam_predictor.set_image(image)

@@ -416,15 +574,15 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
         os.remove(image_path)
         output_images.append(segment_image_result)
 
-    logger.info(f'
+    logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
     if task_type == 'detection' or task_type == 'segment':
-        logger.info(f'
+        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
         return output_images, gr.Gallery.update(label='result images')
     elif task_type == 'inpainting' or task_type == 'remove':
         if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
             task_type = 'remove'
 
-        logger.info(f'
+        logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
         if mask_source_radio == mask_source_draw:
             mask_pil = input_mask_pil
             mask = input_mask

@@ -437,6 +595,8 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
             mask_pil = Image.fromarray(mask)
 
             image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
+            # if reverse_mask:
+            #     mask_pil = mask_pil.point(lambda _: 255-_)
             mask_pil.convert("RGB").save(image_path)
             image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
             os.remove(image_path)

@@ -480,6 +640,8 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
             mask_pil = mix_masks(mask_imgs)
 
             image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
+            # if reverse_mask:
+            #     mask_pil = mask_pil.point(lambda _: 255-_)
             mask_pil.convert("RGB").save(image_path)
             image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
             os.remove(image_path)

@@ -492,25 +654,35 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
             image_inpainting.save(image_path)
             image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
             os.remove(image_path)
-            logger.info(f'
+            logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
             output_images.append(image_result)
             return output_images, gr.Gallery.update(label='result images')
     else:
         logger.info(f"task_type:{task_type} error!")
-        logger.info(f'
+        logger.info(f'run_anything_task_[{file_temp}]_9_9_')
         return output_images, gr.Gallery.update(label='result images')
 
-def change_radio_display(task_type, mask_source_radio):
+def change_radio_display(task_type, mask_source_radio, num_relation): #, gsa_gallery, ram_gallery):
     text_prompt_visible = True
     inpaint_prompt_visible = False
     mask_source_radio_visible = False
+    num_relation_visible = False
+    # gsa_gallery_visible = True
+    # ram_gallery_visible = False
     if task_type == "inpainting":
         inpaint_prompt_visible = True
     if task_type == "inpainting" or task_type == "remove":
         mask_source_radio_visible = True
         if mask_source_radio == mask_source_draw:
             text_prompt_visible = False
-
+    if task_type == "relate anything":
+        text_prompt_visible = False
+        num_relation_visible = True
+        # gsa_gallery_visible = False
+        # ram_gallery_visible = True
+    return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), \
+           gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
+    # gr.Gallery.update(visible=gas_gallery_visible), gr.Gallery.update(visible=ram_gallery_visible)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)

@@ -525,15 +697,16 @@ if __name__ == "__main__":
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload")
-            task_type = gr.Radio(["detection", "segment", "inpainting", "remove"], value="detection",
-                                 label='Task type',
+            task_type = gr.Radio(["detection", "segment", "inpainting", "remove", "relate anything"], value="detection",
+                                 label='Task type', visible=True)
             mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
                                          value=mask_source_segment, label="Mask from",
-
+                                         visible=False)
             text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each name with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
             inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
+            num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
             run_button = gr.Button(label="Run")
-            with gr.Accordion("Advanced options", open=False):
+            with gr.Accordion("Advanced options", open=False) as advanced_options:
                 box_threshold = gr.Slider(
                     label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
                 )

@@ -551,14 +724,16 @@ if __name__ == "__main__":
                 remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
 
         with gr.Column():
-
-
-            ).style(grid=[2], full_width=True, full_height=True)
-
-
-
-
-
+            # gsa_gallery = gr.Gallery(
+            #     label="result images", show_label=True, elem_id="gsa_gallery"
+            # ).style(grid=[2], full_width=True, full_height=True)
+            gallery = gr.Gallery(label="Your Result", show_label=True, elem_id="gallery").style(preview=True, columns=5, object_fit="scale-down")
+
+
+    run_button.click(fn=run_anything_task, inputs=[
+            input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation], outputs=[gsa_gallery, gsa_gallery])
+    task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
+    mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
 
     DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). Thanks for their excellent work.'
     DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
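Note on how the added pieces fit together: SamAutomaticMaskGenerator produces a list of masks, sort_and_deduplicate keeps the distinct ones, each mask's 256-d 'feat' embedding is stacked into a (1, N, 256) tensor, and ram_model.predict scores every ordered mask pair against the 56 relation classes before the top triplets are drawn. Below is a minimal sketch of that call path with random embeddings standing in for real SAM features; it assumes the checkpoint added in this commit and network access for the bert-base-uncased config, and its Predictor subclass mirrors the one in app.py except that it calls eval() where app.py calls train(). Treat it as illustrative, not as the app's exact runtime setup.

import torch
from mmengine.config import Config
from ram_train_eval import RamModel, RamPredictor
from ram_utils import relation_classes

config = Config(dict(
    model=dict(
        pretrained_model_name_or_path='bert-base-uncased',
        load_pretrained_weights=False,
        num_transformer_layer=2,
        input_feature_size=256,
        output_feature_size=768,
        cls_feature_size=512,
        num_relation_classes=56,
        pred_type='attention',
        loss_type='multi_label_ce',
    ),
    load_from="./checkpoints/ram_epoch12.pth",
))

class Predictor(RamPredictor):
    # Same idea as the Predictor subclass in app.py: skip the dataset/dataloader
    # that RamPredictor.__init__ would build and only load the model weights.
    def __init__(self, config, device='cpu'):
        self.config = config
        self.device = torch.device(device)
        self._build_model()

    def _build_model(self):
        self.model = RamModel(**self.config.model).to(self.device)
        if self.config.load_from is not None:
            self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
        self.model.eval()  # app.py's copy calls train() here instead

predictor = Predictor(config)

# Stand-in for the per-mask SAM embeddings: app.py stacks each mask's 256-d
# 'feat' vector into a (1, num_masks, 256) batch before calling predict().
feats = torch.randn(1, 8, 256)
_, triplets = predictor.predict(feats)
for sub_idx, obj_idx, rel_idx in triplets[:5]:
    print(f"mask {sub_idx} --{relation_classes[rel_idx]}--> mask {obj_idx}")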
assets/OpenSans-Bold.ttf
ADDED
Binary file (225 kB)
checkpoints/ram_epoch12.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:036ddbb89e3376b61cb548c8cac3007c3ab7236fb6ac82207d4ccf4039654297
+size 333991817
ram_train_eval.py
ADDED
@@ -0,0 +1,416 @@
+import os
+import time
+from datetime import timedelta
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmengine.config import Config
+from mmengine.utils import ProgressBar
+from transformers import AutoConfig, AutoModel
+
+class RamDataset(torch.utils.data.Dataset):
+    def __init__(self, data_path, is_train=True, num_relation_classes=56):
+        super().__init__()
+        self.num_relation_classes = num_relation_classes
+        data = np.load(data_path, allow_pickle=True)
+        self.samples = data["arr_0"]
+        sample_num = self.samples.size
+        self.sample_idx_list = []
+        for idx in range(sample_num):
+            if self.samples[idx]["is_train"] == is_train:
+                self.sample_idx_list.append(idx)
+
+    def __getitem__(self, idx):
+        sample = self.samples[self.sample_idx_list[idx]]
+        object_num = sample["feat"].shape[0]
+        embedding = torch.from_numpy(sample["feat"])
+        gt_rels = sample["relations"]
+        rel_target = self._get_target(object_num, gt_rels)
+        return embedding, rel_target, gt_rels
+
+    def __len__(self):
+        return len(self.sample_idx_list)
+
+    def _get_target(self, object_num, gt_rels):
+        rel_target = torch.zeros([self.num_relation_classes, object_num, object_num])
+        for ii, jj, cls_relationship in gt_rels:
+            rel_target[cls_relationship, ii, jj] = 1
+        return rel_target
+
+
+class RamModel(nn.Module):
+    def __init__(
+        self,
+        pretrained_model_name_or_path,
+        load_pretrained_weights=True,
+        num_transformer_layer=2,
+        input_feature_size=256,
+        output_feature_size=768,
+        cls_feature_size=512,
+        num_relation_classes=56,
+        pred_type="attention",
+        loss_type="bce",
+    ):
+        super().__init__()
+        # 0. config
+        self.cls_feature_size = cls_feature_size
+        self.num_relation_classes = num_relation_classes
+        self.pred_type = pred_type
+        self.loss_type = loss_type
+
+        # 1. fc input and output
+        self.fc_input = nn.Sequential(
+            nn.Linear(input_feature_size, output_feature_size),
+            nn.LayerNorm(output_feature_size),
+        )
+        self.fc_output = nn.Sequential(
+            nn.Linear(output_feature_size, output_feature_size),
+            nn.LayerNorm(output_feature_size),
+        )
+        # 2. transformer model
+        if load_pretrained_weights:
+            self.model = AutoModel.from_pretrained(pretrained_model_name_or_path)
+        else:
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+            self.model = AutoModel.from_config(config)
+        if num_transformer_layer != "all" and isinstance(num_transformer_layer, int):
+            self.model.encoder.layer = self.model.encoder.layer[:num_transformer_layer]
+        # 3. predict head
+        self.cls_sub = nn.Linear(output_feature_size, cls_feature_size * num_relation_classes)
+        self.cls_obj = nn.Linear(output_feature_size, cls_feature_size * num_relation_classes)
+        # 4. loss
+        if self.loss_type == "bce":
+            self.bce_loss = nn.BCEWithLogitsLoss()
+        elif self.loss_type == "multi_label_ce":
+            print("Use Multi Label Cross Entropy Loss.")
+
+    def forward(self, embeds, attention_mask=None):
+        """
+        embeds: (batch_size, token_num, feature_size)
+        attention_mask: (batch_size, token_num)
+        """
+        # 1. fc input
+        embeds = self.fc_input(embeds)
+        # 2. transformer model
+        position_ids = torch.ones([1, embeds.shape[1]]).to(embeds.device).to(torch.long)
+        outputs = self.model.forward(inputs_embeds=embeds, attention_mask=attention_mask, position_ids=position_ids)
+        embeds = outputs["last_hidden_state"]
+        # 3. fc output
+        embeds = self.fc_output(embeds)
+        # 4. predict head
+        batch_size, token_num, feature_size = embeds.shape
+        sub_embeds = self.cls_sub(embeds).reshape([batch_size, token_num, self.num_relation_classes, self.cls_feature_size]).permute([0, 2, 1, 3])
+        obj_embeds = self.cls_obj(embeds).reshape([batch_size, token_num, self.num_relation_classes, self.cls_feature_size]).permute([0, 2, 1, 3])
+        if self.pred_type == "attention":
+            cls_pred = sub_embeds @ torch.transpose(obj_embeds, 2, 3) / self.cls_feature_size**0.5  # noqa
+        elif self.pred_type == "einsum":
+            cls_pred = torch.einsum("nrsc,nroc->nrso", sub_embeds, obj_embeds)
+        return cls_pred
+
+    def loss(self, pred, target, attention_mask):
+        loss_dict = dict()
+        batch_size, relation_num, _, _ = pred.shape
+
+        mask = torch.zeros_like(pred).to(pred.device)
+        for idx in range(batch_size):
+            n = torch.sum(attention_mask[idx]).to(torch.int)
+            mask[idx, :, :n, :n] = 1
+        pred = pred * mask - 9999 * (1 - mask)
+
+        if self.loss_type == "bce":
+            loss = self.bce_loss(pred, target)
+        elif self.loss_type == "multi_label_ce":
+            input_tensor = torch.permute(pred, (1, 0, 2, 3))
+            target_tensor = torch.permute(target, (1, 0, 2, 3))
+            input_tensor = pred.reshape([relation_num, -1])
+            target_tensor = target.reshape([relation_num, -1])
+            loss = self.multilabel_categorical_crossentropy(target_tensor, input_tensor)
+            weight = loss / loss.max()
+            loss = loss * weight
+            loss = loss.mean()
+        loss_dict["loss"] = loss
+
+        # running metric
+        recall_20 = get_recall_N(pred, target, object_num=20)
+        loss_dict["recall@20"] = recall_20
+        return loss_dict
+
+    def multilabel_categorical_crossentropy(self, y_true, y_pred):
+        """
+        https://kexue.fm/archives/7359
+        """
+        y_pred = (1 - 2 * y_true) * y_pred
+        y_pred_neg = y_pred - y_true * 9999
+        y_pred_pos = y_pred - (1 - y_true) * 9999
+        zeros = torch.zeros_like(y_pred[..., :1])
+        y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
+        y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
+        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
+        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
+        return neg_loss + pos_loss
+
+
+def get_recall_N(y_pred, y_true, object_num=20):
+    """
+    y_pred: [batch_size, 56, object_num, object_num]
+    y_true: [batch_size, 56, object_num, object_num]
+    """
+
+    device = y_pred.device
+    recall_list = []
+
+    for idx in range(len(y_true)):
+        sample_y_true = []
+        sample_y_pred = []
+
+        # find topk
+        _, topk_indices = torch.topk(
+            y_true[idx : idx + 1].reshape(
+                [
+                    -1,
+                ]
+            ),
+            k=object_num,
+        )
+        for index in topk_indices:
+            pred_cls = index // (y_true.shape[2] ** 2)
+            index_subject_object = index % (y_true.shape[2] ** 2)
+            pred_subject = index_subject_object // y_true.shape[2]
+            pred_object = index_subject_object % y_true.shape[2]
+            if y_true[idx, pred_cls, pred_subject, pred_object] == 0:
+                continue
+            sample_y_true.append([pred_subject, pred_object, pred_cls])
+
+        # find topk
+        _, topk_indices = torch.topk(
+            y_pred[idx : idx + 1].reshape(
+                [
+                    -1,
+                ]
+            ),
+            k=object_num,
+        )
+        for index in topk_indices:
+            pred_cls = index // (y_pred.shape[2] ** 2)
+            index_subject_object = index % (y_pred.shape[2] ** 2)
+            pred_subject = index_subject_object // y_pred.shape[2]
+            pred_object = index_subject_object % y_pred.shape[2]
+            sample_y_pred.append([pred_subject, pred_object, pred_cls])
+
+        recall = len([x for x in sample_y_pred if x in sample_y_true]) / (len(sample_y_true) + 1e-8)
+        recall_list.append(recall)
+
+    recall = torch.tensor(recall_list).to(device).mean() * 100
+    return recall
+
+
+class RamTrainer(object):
+    def __init__(self, config):
+        self.config = config
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self._build_dataset()
+        self._build_dataloader()
+        self._build_model()
+        self._build_optimizer()
+        self._build_lr_scheduler()
+
+    def _build_dataset(self):
+        self.dataset = RamDataset(**self.config.dataset)
+
+    def _build_dataloader(self):
+        self.dataloader = torch.utils.data.DataLoader(
+            self.dataset,
+            batch_size=self.config.dataloader.batch_size,
+            shuffle=True if self.config.dataset.is_train else False,
+        )
+
+    def _build_model(self):
+        self.model = RamModel(**self.config.model).to(self.device)
+        if self.config.load_from is not None:
+            self.model.load_state_dict(torch.load(self.config.load_from))
+        self.model.train()
+
+    def _build_optimizer(self):
+        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.optim.lr, weight_decay=self.config.optim.weight_decay, eps=self.config.optim.eps, betas=self.config.optim.betas)
+
+    def _build_lr_scheduler(self):
+        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.config.optim.lr_scheduler.step, gamma=self.config.optim.lr_scheduler.gamma)
+
+    def train(self):
+        t_start = time.time()
+        running_avg_loss = 0
+        for epoch_idx in range(self.config.num_epoch):
+            for batch_idx, batch_data in enumerate(self.dataloader):
+                batch_embeds = batch_data[0].to(torch.float32).to(self.device)
+                batch_target = batch_data[1].to(torch.float32).to(self.device)
+                attention_mask = batch_embeds.new_ones((batch_embeds.shape[0], batch_embeds.shape[1]))
+                batch_pred = self.model.forward(batch_embeds, attention_mask)
+                loss_dict = self.model.loss(batch_pred, batch_target, attention_mask)
+                loss = loss_dict["loss"]
+                recall_20 = loss_dict["recall@20"]
+                self.optimizer.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.optim.max_norm, self.config.optim.norm_type)
+                self.optimizer.step()
+                running_avg_loss += loss.item()
+
+                if batch_idx % 100 == 0:
+                    t_current = time.time()
+                    num_finished_step = epoch_idx * self.config.num_epoch * len(self.dataloader) + batch_idx + 1
+                    num_to_do_step = (self.config.num_epoch - epoch_idx - 1) * len(self.dataloader) + (len(self.dataloader) - batch_idx - 1)
+                    avg_speed = num_finished_step / (t_current - t_start)
+                    eta = num_to_do_step / avg_speed
+                    print(
+                        "ETA={:0>8}, Epoch={}, Batch={}/{}, LR={}, Loss={:.4f}, RunningAvgLoss={:.4f}, Recall@20={:.2f}%".format(
+                            str(timedelta(seconds=int(eta))), epoch_idx + 1, batch_idx, len(self.dataloader), self.lr_scheduler.get_last_lr()[0], loss.item(), running_avg_loss / num_finished_step, recall_20.item()
+                        )
+                    )
+            self.lr_scheduler.step()
+            if not os.path.exists(self.config.output_dir):
+                os.makedirs(self.config.output_dir)
+            save_path = os.path.join(self.config.output_dir, "epoch_{}.pth".format(epoch_idx + 1))
+            print("Save epoch={} checkpoint to {}".format(epoch_idx + 1, save_path))
+            torch.save(self.model.state_dict(), save_path)
+        return save_path
+
+
+class RamPredictor(object):
+    def __init__(self, config):
+        self.config = config
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self._build_dataset()
+        self._build_dataloader()
+        self._build_model()
+
+    def _build_dataset(self):
+        self.dataset = RamDataset(**self.config.dataset)
+
+    def _build_dataloader(self):
+        self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=self.config.dataloader.batch_size, shuffle=False)
+
+    def _build_model(self):
+        self.model = RamModel(**self.config.model).to(self.device)
+        if self.config.load_from is not None:
+            self.model.load_state_dict(torch.load(self.config.load_from))
+        self.model.eval()
+
+    def predict(self, batch_embeds, pred_keep_num=100):
+        """
+        Parameters
+        ----------
+        batch_embeds: (batch_size=1, token_num, feature_size)
+        pred_keep_num: int
+        Returns
+        -------
+        batch_pred: (batch_size, relation_num, object_num, object_num)
+        pred_rels: [[sub_id, obj_id, rel_id], ...]
+        """
+        if not isinstance(batch_embeds, torch.Tensor):
+            batch_embeds = torch.asarray(batch_embeds)
+        batch_embeds = batch_embeds.to(torch.float32).to(self.device)
+        attention_mask = batch_embeds.new_ones((batch_embeds.shape[0], batch_embeds.shape[1]))
+        batch_pred = self.model.forward(batch_embeds, attention_mask)
+        for idx_i in range(batch_pred.shape[2]):
+            batch_pred[:, :, idx_i, idx_i] = -9999
+        batch_pred = batch_pred.sigmoid()
+
+        pred_rels = []
+        _, topk_indices = torch.topk(
+            batch_pred.reshape(
+                [
+                    -1,
+                ]
+            ),
+            k=pred_keep_num,
+        )
+
+        # subject, object, relation
+        for index in topk_indices:
+            pred_relation = index // (batch_pred.shape[2] ** 2)
+            index_subject_object = index % (batch_pred.shape[2] ** 2)
+            pred_subject = index_subject_object // batch_pred.shape[2]
+            pred_object = index_subject_object % batch_pred.shape[2]
+            pred = [pred_subject.item(), pred_object.item(), pred_relation.item()]
+            pred_rels.append(pred)
+        return batch_pred, pred_rels
+
+    def eval(self):
+        sum_recall_20 = 0.0
+        sum_recall_50 = 0.0
+        sum_recall_100 = 0.0
+        prog_bar = ProgressBar(len(self.dataloader))
+        for batch_idx, batch_data in enumerate(self.dataloader):
+            batch_embeds = batch_data[0]
+            batch_target = batch_data[1]
+            gt_rels = batch_data[2]
+            batch_pred, pred_rels = self.predict(batch_embeds)
+            this_recall_20 = get_recall_N(batch_pred, batch_target, object_num=20)
+            this_recall_50 = get_recall_N(batch_pred, batch_target, object_num=50)
+            this_recall_100 = get_recall_N(batch_pred, batch_target, object_num=100)
+            sum_recall_20 += this_recall_20.item()
+            sum_recall_50 += this_recall_50.item()
+            sum_recall_100 += this_recall_100.item()
+            prog_bar.update()
+        recall_20 = sum_recall_20 / len(self.dataloader)
+        recall_50 = sum_recall_50 / len(self.dataloader)
+        recall_100 = sum_recall_100 / len(self.dataloader)
+        metric = {
+            "recall_20": recall_20,
+            "recall_50": recall_50,
+            "recall_100": recall_100,
+        }
+        return metric
+
+
+if __name__ == "__main__":
+    # Config
+    config = dict(
+        dataset=dict(
+            data_path="./data/feat_0420.npz",
+            is_train=True,
+            num_relation_classes=56,
+        ),
+        dataloader=dict(
+            batch_size=4,
+        ),
+        model=dict(
+            pretrained_model_name_or_path="bert-base-uncased",
+            load_pretrained_weights=True,
+            num_transformer_layer=2,
+            input_feature_size=256,
+            output_feature_size=768,
+            cls_feature_size=512,
+            num_relation_classes=56,
+            pred_type="attention",
+            loss_type="multi_label_ce",
+        ),
+        optim=dict(
+            lr=1e-4,
+            weight_decay=0.05,
+            eps=1e-8,
+            betas=(0.9, 0.999),
+            max_norm=0.01,
+            norm_type=2,
+            lr_scheduler=dict(
+                step=[6, 10],
+                gamma=0.1,
+            ),
+        ),
+        num_epoch=12,
+        output_dir="./work_dirs",
+        load_from=None,
+    )
+
+    # Train
+    config = Config(config)
+    trainer = RamTrainer(config)
+    last_model_path = trainer.train()
+
+    # Test/Eval
+    config.dataset.is_train = False
+    config.load_from = last_model_path
+    predictor = RamPredictor(config)
+    metric = predictor.eval()
+    print(metric)
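The __main__ block above trains for 12 epochs and then evaluates the last checkpoint. A hedged sketch of running only the evaluation half against an existing checkpoint; the data path is a placeholder, and the .npz must follow the arr_0 sample layout (feat / relations / is_train) that RamDataset expects:

from mmengine.config import Config
from ram_train_eval import RamPredictor

# Placeholder paths; load_pretrained_weights=False only fetches the BERT config,
# since the weights come from the checkpoint loaded via load_from.
config = Config(dict(
    dataset=dict(data_path="./data/feat_0420.npz", is_train=False, num_relation_classes=56),
    dataloader=dict(batch_size=4),
    model=dict(
        pretrained_model_name_or_path="bert-base-uncased",
        load_pretrained_weights=False,
        num_transformer_layer=2,
        input_feature_size=256,
        output_feature_size=768,
        cls_feature_size=512,
        num_relation_classes=56,
        pred_type="attention",
        loss_type="multi_label_ce",
    ),
    load_from="./checkpoints/ram_epoch12.pth",
))
print(RamPredictor(config).eval())  # {'recall_20': ..., 'recall_50': ..., 'recall_100': ...}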
ram_utils.py
ADDED
@@ -0,0 +1,152 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+import torch.nn.functional as F
+
+
+class MLP(nn.Module):
+    def __init__(self, input_size, hidden_size, num_classes, dropout_prob=0.1):
+        super(MLP, self).__init__()
+        self.fc1 = nn.Linear(input_size, hidden_size)
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(dropout_prob)
+        self.fc2 = nn.Linear(hidden_size, num_classes)
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu(out)
+        out = self.dropout(out)
+        out = self.fc2(out)
+        return out
+
+
+def show_anns(anns, color_code='auto'):
+    if len(anns) == 0:
+        return
+    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+    ax = plt.gca()
+    ax.set_autoscale_on(False)
+    polygons = []
+    color = []
+    for ann in sorted_anns:
+        m = ann['segmentation']
+        img = np.ones((m.shape[0], m.shape[1], 3))
+        color_mask = np.random.random((1, 3)).tolist()[0]
+        if color_code == 'auto':
+            for i in range(3):
+                img[:,:,i] = color_mask[i]
+        elif color_code == 'red':
+            for i in range(3):
+                img[:,:,0] = 1
+                img[:,:,1] = 0
+                img[:,:,2] = 0
+        else:
+            for i in range(3):
+                img[:,:,0] = 0
+                img[:,:,1] = 0
+                img[:,:,2] = 1
+        return np.dstack((img, m*0.35))
+
+
+def show_points(coords, labels, ax, marker_size=375):
+    pos_points = coords[labels==1]
+    neg_points = coords[labels==0]
+    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
+               s=marker_size, edgecolor='white', linewidth=1.25)
+    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
+               s=marker_size, edgecolor='white', linewidth=1.25)
+
+def show_mask(m):
+    img = np.ones((m.shape[0], m.shape[1], 3))
+    color_mask = np.random.random((1, 3)).tolist()[0]
+    for i in range(3):
+        img[:,:,0] = 1
+        img[:,:,1] = 0
+        img[:,:,2] = 0
+
+    return np.dstack((img, m*0.35))
+
+
+def iou(mask1, mask2):
+    intersection = np.logical_and(mask1, mask2)
+    union = np.logical_or(mask1, mask2)
+    iou_score = np.sum(intersection) / np.sum(union)
+    return iou_score
+
+
+def sort_and_deduplicate(sam_masks, iou_threshold=0.8):
+    # Sort the sam_masks list based on the area value
+    sorted_masks = sorted(sam_masks, key=lambda x: x['area'], reverse=True)
+
+    # Deduplicate masks based on the given iou_threshold
+    filtered_masks = []
+    for mask in sorted_masks:
+        duplicate = False
+        for filtered_mask in filtered_masks:
+            if iou(mask['segmentation'], filtered_mask['segmentation']) > iou_threshold:
+                duplicate = True
+                break
+
+        if not duplicate:
+            filtered_masks.append(mask)
+
+    return filtered_masks
+
+
+relation_classes = ['over',
+                    'in front of',
+                    'beside',
+                    'on',
+                    'in',
+                    'attached to',
+                    'hanging from',
+                    'on back of',
+                    'falling off',
+                    'going down',
+                    'painted on',
+                    'walking on',
+                    'running on',
+                    'crossing',
+                    'standing on',
+                    'lying on',
+                    'sitting on',
+                    'flying over',
+                    'jumping over',
+                    'jumping from',
+                    'wearing',
+                    'holding',
+                    'carrying',
+                    'looking at',
+                    'guiding',
+                    'kissing',
+                    'eating',
+                    'drinking',
+                    'feeding',
+                    'biting',
+                    'catching',
+                    'picking',
+                    'playing with',
+                    'chasing',
+                    'climbing',
+                    'cleaning',
+                    'playing',
+                    'touching',
+                    'pushing',
+                    'pulling',
+                    'opening',
+                    'cooking',
+                    'talking to',
+                    'throwing',
+                    'slicing',
+                    'driving',
+                    'riding',
+                    'parked on',
+                    'driving on',
+                    'about to hit',
+                    'kicking',
+                    'swinging',
+                    'entering',
+                    'exiting',
+                    'enclosing',
+                    'leaning on',]
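sort_and_deduplicate expects the list of dicts that SamAutomaticMaskGenerator returns, each with a boolean 'segmentation' array and an 'area'; it keeps the largest mask of any group whose pairwise IoU exceeds the threshold. A small self-contained check with hand-made masks (toy arrays, not SAM output):

import numpy as np
from ram_utils import iou, sort_and_deduplicate

a = np.zeros((4, 4), dtype=bool); a[:3, :3] = True                   # area 9
b = np.zeros((4, 4), dtype=bool); b[:3, :3] = True; b[3, 3] = True   # area 10, overlaps a heavily
c = np.zeros((4, 4), dtype=bool); c[3, :3] = True                    # disjoint from a and b

masks = [
    {'segmentation': a, 'area': int(a.sum())},
    {'segmentation': b, 'area': int(b.sum())},
    {'segmentation': c, 'area': int(c.sum())},
]
print(iou(a, b))                         # 0.9 -> a is dropped as a duplicate of the larger b
print(len(sort_and_deduplicate(masks)))  # 2 masks survive with the default 0.8 threshold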
requirements.txt
CHANGED
@@ -22,11 +22,6 @@ yapf
 numba
 segment_anything
 
-# ftfy
-# uuid
-# psutil
-# facexlib
 lama-cleaner==0.25.0
-
-
-
+openmim==0.1.5
+mmcv==2.0.0