VRIS_vip / mbench /check_image_numbered_cy_score.py
dianecy's picture
Add files using upload-large-folder tool
2c58401 verified
raw
history blame
8.64 kB
import sys
import os
import argparse
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import opts
import numpy as np
import cv2
from PIL import Image
import json
from mbench.ytvos_ref import build as build_ytvos_ref
import t2v_metrics
import matplotlib.pyplot as plt
import textwrap
def _load_frames(img_folder, vid_id, frames, sampled_frames):
    """Load the RGB images and palette masks for the sampled frames of one video.

    Args:
        img_folder: dataset root containing ``JPEGImages/`` and ``Annotations/``.
        vid_id: video directory name.
        frames: ordered list of frame file stems for the video.
        sampled_frames: frame indices (as strings or ints) into ``frames``.

    Returns:
        (imgs, masks): two lists of numpy arrays, aligned with ``sampled_frames``.
    """
    imgs, masks = [], []
    for frame_indx in sampled_frames:
        frame_name = frames[int(frame_indx)]
        img_path = os.path.join(img_folder, 'JPEGImages', vid_id, frame_name + '.jpg')
        mask_path = os.path.join(img_folder, 'Annotations', vid_id, frame_name + '.png')
        # Context managers close the underlying file handles after decoding.
        with Image.open(img_path) as img:
            imgs.append(np.array(img.convert('RGB')))
        with Image.open(mask_path) as mask:
            # Palette mode: each pixel value is an object id (0 = background).
            masks.append(np.array(mask.convert('P')))
    return imgs, masks


def _contour_center(contour):
    """Return the integer (x, y) centroid of an OpenCV contour.

    For a zero-area contour (``m00 == 0``) the contour's first point is used,
    so the label still lands on the object rather than at the image corner.
    """
    moments = cv2.moments(contour)
    if moments["m00"] != 0:
        return (int(moments["m10"] / moments["m00"]),
                int(moments["m01"] / moments["m00"]))
    # Cast from numpy ints to plain ints so OpenCV drawing calls accept them.
    x, y = contour[0][0]
    return int(x), int(y)


def _draw_outline_with_labels(frame, obj_mask, obj_id, color):
    """Outline every external contour of ``obj_mask`` and label each with ``obj_id``.

    Draws in place on ``frame``: a 3px outline in ``color`` and, at each
    contour's centre, the id in white text on a filled black box.
    """
    contours, _ = cv2.findContours(obj_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(frame, contours, -1, color, 3)
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_w, text_h = cv2.getTextSize(obj_id, font, 1, 2)[0]
    for contour in contours:
        cx, cy = _contour_center(contour)
        # Black background box centred on the contour, padded by 5px.
        cv2.rectangle(frame, (cx - text_w // 2 - 5, cy - text_h // 2 - 5),
                      (cx + text_w // 2 + 5, cy + text_h // 2 + 5), (0, 0, 0), -1)
        # White id text centred inside the box.
        cv2.putText(frame, obj_id, (cx - text_w // 2, cy + text_h // 2),
                    font, 1, (255, 255, 255), 2)


def _draw_tint_with_label(frame, obj_mask, obj_id, color, alpha=0.08):
    """Tint ``obj_mask`` pixels with ``color``, outline it, and label its largest part.

    Draws in place on ``frame``. ``alpha`` is the weight of ``color`` in the blend.
    """
    selected = obj_mask == 1
    # Float blend; assigning back into the uint8 frame truncates to integers.
    frame[selected] = (1 - alpha) * frame[selected] + alpha * np.asarray(color)
    contours, _ = cv2.findContours(obj_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(frame, contours, -1, color, 2)
    if not contours:
        return
    center_x, center_y = _contour_center(max(contours, key=cv2.contourArea))
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.9
    (text_w, text_h), _ = cv2.getTextSize(obj_id, font, font_scale, 2)
    # Centre the text horizontally on the centroid; its baseline sits at center_y.
    text_x = center_x - text_w // 2
    text_y = center_y
    cv2.rectangle(frame, (text_x - 5, text_y - text_h - 5),
                  (text_x + text_w + 5, text_y), (0, 0, 0), -1)
    # Draw at the same scale the box was measured with so the text fits the box.
    cv2.putText(frame, obj_id, (text_x, text_y), font, font_scale, (255, 255, 255), 2)


def _score_to_float(scores):
    """Convert the VQAScore result for a single image/text pair to a plain float.

    Per the t2v_metrics README, the scorer returns an (n_images, n_texts)
    tensor. Tensors are not JSON-serializable, so the single value is extracted.
    """
    if hasattr(scores, "item"):
        return float(scores.item())
    return float(scores)


def scoreCaption(idx, all_captions, all_valid_obj_ids, clip_flant5_score, color_mask = False):
    """Score each category's per-frame caption of one video with a VQA scorer.

    For every category captioned for the video, each sampled frame is
    annotated with the ids of that category's valid objects. The annotation
    is either an outline plus id labels (``color_mask`` false) or a light
    colour tint plus an outline and an id label. The annotated frame and the
    caption are then passed to ``clip_flant5_score``.

    Relies on module-level globals set by the ``__main__`` block:
    ``metas``, ``train_dataset`` (for ``img_folder``) and ``colors``.

    Args:
        idx: index into ``metas``.
        all_captions: ``{vid_id: {cat: {frame_key: caption}}}``.
        all_valid_obj_ids: ``{vid_id: {cat: [obj_id_str, ...]}}``.
        clip_flant5_score: callable ``(images=[PIL.Image], texts=[str])``.
        color_mask: use the tint style instead of outline-only labelling.

    Returns:
        ``(vid_id, {cat: {frame_key: {"caption": str, "score": float | None}}})``.
        The score is None when the frame has no caption (missing or empty).
    """
    vid_meta = metas[idx]
    vid_id = vid_meta['video']
    vid_captions = all_captions[vid_id]

    # Frame keys are taken from the first category; presumably every
    # category was captioned on the same sampled frames (TODO confirm).
    first_cat = next(iter(vid_captions))
    sampled_frames = list(vid_captions[first_cat])
    imgs, masks = _load_frames(str(train_dataset.img_folder), vid_id,
                               vid_meta['frames'], sampled_frames)

    valid_by_cat = all_valid_obj_ids.get(vid_id, {})
    vid_result = {}
    # Iterate the dict (not a set) so the output order is deterministic.
    for cat, cat_captions in vid_captions.items():
        valid_obj_ids = valid_by_cat.get(cat, [])
        cat_result = {}
        for frame_name, img, mask in zip(sampled_frames, imgs, masks):
            frame = np.copy(img)
            obj_ids = [str(obj_id) for obj_id in np.unique(mask).astype(int) if obj_id != 0]
            for j, obj_id in enumerate(obj_ids):
                if obj_id not in valid_obj_ids:
                    continue
                # The colour follows the object's position among the frame's ids;
                # wrap around so frames with more objects than colours don't crash.
                color = colors[j % len(colors)]
                obj_mask = (mask == int(obj_id)).astype(np.uint8)
                if color_mask:
                    _draw_tint_with_label(frame, obj_mask, obj_id, color)
                else:
                    _draw_outline_with_labels(frame, obj_mask, obj_id, color)

            frame_caption = cat_captions.get(frame_name)
            if frame_caption:
                scores = clip_flant5_score(images=[Image.fromarray(frame)], texts=[frame_caption])
                score = _score_to_float(scores)
            else:
                score = None
            cat_result[frame_name] = {
                "caption" : frame_caption,
                "score" : score
            }
        vid_result[cat] = cat_result
    return vid_id, vid_result
if __name__ == '__main__':
    parser = argparse.ArgumentParser('ReferFormer training and evaluation script', parents=[opts.get_args_parser()])
    args = parser.parse_args()

    # ================== Load data ==================
    # Full training dataset and its per-video metadata list.
    train_dataset = build_ytvos_ref(image_set = 'train', args = args)
    metas = train_dataset.metas

    # Numbered captions and per-category valid object ids. Read as UTF-8
    # explicitly so non-ASCII captions decode the same on every platform.
    with open('mbench/numbered_captions_gpt-4o_final.json', 'r', encoding='utf-8') as file:
        all_captions = json.load(file)
    with open('mbench/numbered_valid_obj_ids_gpt-4o_final.json', 'r', encoding='utf-8') as file:
        all_valid_obj_ids = json.load(file)

    # Eight candidate annotation colours (RGB); read as a global by scoreCaption.
    colors = [
        (255, 0, 0),      # Red
        (0, 255, 0),      # Green
        (0, 0, 255),      # Blue
        (255, 255, 0),    # Yellow
        (255, 0, 255),    # Magenta
        (0, 255, 255),    # Cyan
        (128, 0, 128),    # Purple
        (255, 165, 0)     # Orange
    ]

    # ================== Load the VQA score model ==================
    clip_flant5_score = t2v_metrics.VQAScore(model='clip-flant5-xxl')

    # ================== Compute VQA scores ==================
    # Only the first 5 videos are scored; clamp in case the dataset is smaller.
    num_videos = min(5, len(metas))
    all_scores = {}
    for i in range(num_videos):
        vid_id, vid_result = scoreCaption(i, all_captions, all_valid_obj_ids, clip_flant5_score, False)
        all_scores[vid_id] = vid_result

    def _json_default(obj):
        """Serialize scalar tensors / numpy scalars (e.g. VQA scores) as Python numbers."""
        if hasattr(obj, 'item'):
            return obj.item()
        raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

    with open('mbench/numbered_captions_gpt-4o_final_scores.json', 'w', encoding='utf-8') as json_file:
        # The file object must be passed; without it json.dump raises TypeError.
        json.dump(all_scores, json_file, indent=4, ensure_ascii=False, default=_json_default)
    print("JSON 파일이 성공적으로 저장되었습니다!")