In [32]:
import os
import argparse
import sys
import opts
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import textwrap

from PIL import Image, ImageDraw
import json
import numpy as np
from mbench.ytvos_ref import build as build_ytvos_ref

In [26]:
# Dataset split root; showImageRef joins it with the 'JPEGImages' and
# 'Annotations' sub-folders to locate a frame and its mask.
img_folder = "data/ref-youtube-vos/train"

# Caption colours, looked up by each object's `isValid` value (index 0 or 1).
text_colors = ["red", "blue"]

In [2]:
# Parse the JSON results file into the notebook-level `data` object.
# showImageRef indexes it as data[vid_id][category][frame][obj_id].
with open('mbench/result_revised50.json') as file:
    data = json.load(file)

In [24]:
def bounding_box(img):
    """Return the tight bounding box of the non-zero region of a 2-D array.

    Args:
        img: 2-D array-like; any truthy element counts as foreground.

    Returns:
        Tuple ``(rmin, rmax, cmin, cmax)``, i.e. ``(y1, y2, x1, x2)``. Both
        ends are inclusive indices of the first/last row and column that
        contain a non-zero element.

    Raises:
        IndexError: if ``img`` has no non-zero element.
    """
    # Indices of the rows / columns holding at least one non-zero element.
    occupied_rows = np.nonzero(np.any(img, axis=1))[0]
    occupied_cols = np.nonzero(np.any(img, axis=0))[0]

    top, bottom = occupied_rows[0], occupied_rows[-1]
    left, right = occupied_cols[0], occupied_cols[-1]
    return top, bottom, left, right  # y1, y2, x1, x2

In [97]:
def showImageRef(vid_id):
    """Show every annotated frame of one video with per-object boxes and captions.

    Walks ``data[vid_id][category][frame]``. For each frame whose entry is
    non-empty, one figure is created with one subplot per object id. Each
    subplot shows the frame image, a red rectangle around the pixels of the
    annotation mask equal to that object id, and the object's referring
    expression underneath, coloured by ``text_colors[isValid]``.

    Args:
        vid_id: Key into the notebook-level ``data`` dict; also used as the
            directory name under ``<img_folder>/JPEGImages`` and
            ``<img_folder>/Annotations``.

    Returns:
        None. Figures are displayed with ``plt.show()`` and then closed.

    Notes:
        * Object ids must be convertible with ``int()``; they are compared
          against the palette-index values of the mask PNG.
        * Objects whose entry is empty leave their subplot blank.
        * An object that is absent from the mask gets a zero-sized rectangle
          at the origin (previously this path hit unbound ``x1``/``y1``).
    """
    vid_data = data[vid_id]

    for cat, cat_data in vid_data.items():
        for frame, frame_data in cat_data.items():
            # Nothing to draw for this frame: skip before touching the disk.
            if not frame_data:
                continue

            img_path = os.path.join(img_folder, 'JPEGImages', vid_id, frame + '.jpg')
            mask_path = os.path.join(img_folder, 'Annotations', vid_id, frame + '.png')
            img = Image.open(img_path).convert('RGB')
            # Palette mode keeps one integer id per pixel.
            mask = np.array(Image.open(mask_path).convert('P'))

            obj_ids = list(frame_data.keys())
            obj_nums = len(obj_ids)

            # squeeze=False always yields a 2-D axes array, so the
            # single-object case needs no special handling.
            fig, axes = plt.subplots(1, obj_nums, figsize=(16, obj_nums), squeeze=False)

            for ax, obj_id in zip(axes[0], obj_ids):
                obj_data = frame_data[obj_id]
                if not obj_data:
                    continue

                ref_exp = obj_data['ref_exp']
                isValid = obj_data['isValid']

                obj_mask = mask == int(obj_id)
                if obj_mask.any():
                    y1, y2, x1, x2 = bounding_box(obj_mask)
                else:
                    # Object not visible in this frame: degenerate box at the
                    # origin, never stale coordinates from a previous object.
                    x1 = y1 = x2 = y2 = 0

                ax.imshow(img)
                rect = patches.Rectangle(
                    (x1, y1), x2 - x1, y2 - y1,
                    linewidth=2, edgecolor='red', facecolor='none',
                )
                ax.add_patch(rect)

                wrapped_text = "\n".join(textwrap.wrap(ref_exp, width=30))
                ax.annotate(
                    wrapped_text, xy=(0.5, -1.5), xycoords="axes fraction",
                    ha="center", color=text_colors[isValid],
                )

            fig.suptitle(f"video: {vid_id} - cat: {cat} - frame: {frame}")
            plt.show()
            # Release the figure; one is created per frame inside the loop.
            plt.close(fig)

In [142]:
# Pick the video at position 49 (0-based) in `data`'s key order, echo its id,
# then render all of its annotated frames.
vid_id = list(data)[49]
print(vid_id)
showImageRef(vid_id)

04667fabaa
