CV-Agent / utils.py
Samarth991's picture
added object detection improved
3945649
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import cm
import torch
import cv2
import random
def draw_panoptic_segmentation(model, segmentation, segments_info,
                               output_path='final_mask.png', show_legend=False):
    """Render a panoptic segmentation map with Matplotlib and save it to disk.

    Args:
        model: Object exposing ``config.id2label``; used to map each
            segment's ``label_id`` to a readable class name.
        segmentation: Tensor of segment ids. It is moved to CPU and
            converted to a NumPy array before being shown with ``imshow``.
        segments_info: Iterable of dicts, each providing at least the keys
            ``'id'`` and ``'label_id'``.
        output_path: File the figure is written to. Defaults to
            ``'final_mask.png'``, the path the function has always used.
        show_legend: When true, add a legend with one entry per segment
            (``"<label>-<instance index>"``). Off by default, matching the
            previous behaviour where the legend call was disabled.

    Returns:
        The path of the saved figure (``output_path``).
    """
    # The colormap size must be a plain positive int; torch.max returns a
    # 0-dim tensor, and a map with no positive id would otherwise request an
    # empty lookup table.
    num_colors = max(int(torch.max(segmentation)), 1)
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    viridis = plt.get_cmap('viridis', num_colors)

    fig, ax = plt.subplots()
    try:
        ax.imshow(segmentation.cpu().numpy())

        # Per-class running index so repeated classes get distinct legend names.
        instances_counter = defaultdict(int)
        handles = []
        for segment in segments_info:
            segment_id = segment['id']
            segment_label_id = segment['label_id']
            segment_label = model.config.id2label[segment_label_id]
            label = f"{segment_label}-{instances_counter[segment_label_id]}"
            instances_counter[segment_label_id] += 1
            color = viridis(segment_id)
            handles.append(mpatches.Patch(color=color, label=label))

        if show_legend:
            ax.legend(handles=handles)
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call would leak one figure.
        plt.close(fig)
    return output_path
def draw_bboxes(rgb_frame, boxes, labels, line_thickness=3):
    """Draw labelled bounding boxes on an image and return the annotated copy.

    Args:
        rgb_frame: Path of the image file, loaded with ``cv2.imread``. An
            already-loaded image array is accepted as well.
        boxes: Iterable of boxes, each holding exactly four values
            ``(x1, y1, x2, y2)``. A box may be a torch tensor (any dtype or
            device), a NumPy array or a plain sequence; coordinates are
            rounded to integer pixels.
        labels: Iterable of labels, paired with ``boxes`` in order. Each
            distinct label gets its own random colour, and ``str(label)`` is
            drawn next to the box.
        line_thickness: Rectangle line thickness in pixels. A falsy value
            (``None`` or ``0``) derives the thickness from the image size.

    Returns:
        A copy of the image with boxes and labels drawn. The channel order
        is left as loaded (``cv2.imread`` yields BGR; no conversion is done).

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load ``rgb_frame``.
    """
    if isinstance(rgb_frame, np.ndarray):
        image = rgb_frame
    else:
        image = cv2.imread(rgb_frame)
        # imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {rgb_frame!r}")

    tl = line_thickness or round(0.002 * (image.shape[0] + image.shape[1]) / 2) + 1  # line thickness
    tf = max(tl - 1, 1)  # font thickness
    font_scale = tl / 3
    annotated = image.copy()

    # Unwrap scalar tensors / NumPy scalars to plain Python values: tensors
    # hash by identity, so they cannot serve as colour-lookup keys directly.
    label_keys = [label.item() if hasattr(label, 'item') else label for label in labels]
    color_dict = {
        key: [random.randint(28, 255) for _ in range(3)]
        for key in dict.fromkeys(label_keys)
    }

    for box, key in zip(boxes, label_keys):
        if isinstance(box, torch.Tensor):
            box = box.detach().cpu().numpy()
        # OpenCV drawing calls need plain integer pixel coordinates.
        x1, y1, x2, y2 = (int(round(float(v))) for v in np.asarray(box).ravel())
        color = color_dict[key]
        text = str(key)

        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, thickness=tl, lineType=cv2.LINE_AA)

        text_height = cv2.getTextSize(text, 0, fontScale=font_scale, thickness=tf)[0][1]
        text_y = y1 - 2
        if text_y - text_height < 0:
            # The label would be cut off by the top edge; place it just
            # inside the box instead.
            text_y = y1 + text_height + 2
        cv2.putText(annotated, text, (x1, text_y), 0, font_scale, color, thickness=tf, lineType=cv2.LINE_AA)
    return annotated