Spaces:
Running
Running
File size: 10,853 Bytes
ca37b38 615e9f1 b76c717 615e9f1 27a202c 615e9f1 27a202c 615e9f1 cc79c19 615e9f1 6ceb9bd 615e9f1 27a202c 615e9f1 ebef706 615e9f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
from modules.utils import class_dict, resize_boxes, resize_keypoints, find_other_keypoint
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
from modules.OCR import group_texts
def draw_stream(image,
                prediction=None,
                text_predictions=None,
                class_dict=class_dict,
                draw_keypoints=False,
                draw_boxes=False,
                draw_text=False,
                draw_links=False,
                draw_twins=False,
                draw_grouped_text=False,
                write_class=False,
                write_score=False,
                write_text=False,
                score_threshold=0.4,
                write_idx=False,
                keypoints_correction=False,
                new_size=(1333, 1333),
                only_show=None,
                axis=False,
                return_image=False,
                resize=False):
    """
    Draw model predictions and OCR results on a copy of an image.

    The input image is never modified; all drawing happens on a copy that is
    either returned or displayed with matplotlib.

    Parameters:
    - image: Image to annotate. When `prediction` is None it is converted with
      `squeeze(0).permute(1, 2, 0).cpu().numpy()` (presumably a batched CHW
      torch tensor -- confirm with callers); otherwise it is used as-is and
      must provide `.copy()` and `.shape`.
    - prediction (dict | None): Model output. The keys read here are 'boxes',
      'labels', 'scores', and optionally 'keypoints' and 'links'. When None,
      every prediction-based annotation is skipped.
    - text_predictions (tuple | None): OCR output; index 0 holds the text
      boxes and index 1 the matching strings.
    - class_dict (dict): Mapping from class IDs to class names.
    - draw_keypoints (bool): Draw the keypoints of sequenceFlow, messageFlow
      and dataAssociation objects.
    - draw_boxes (bool): Draw bounding boxes (lanes are skipped).
    - draw_text (bool): Draw the OCR text boxes.
    - draw_links (bool): Draw lines from each linked object to the centers of
      its start and end objects.
    - draw_twins (bool): Mark flow objects whose two keypoints are less than
      10 pixels apart.
    - draw_grouped_text (bool): Draw the boxes returned by `group_texts` for
      the 'task' objects. Requires `text_predictions`.
    - write_class (bool): Write the class name next to each box.
    - write_score (bool): Write the score next to each box.
    - write_text (bool): Write the OCR recognized text.
    - score_threshold (float): Boxes and keypoints scoring below this value
      are not drawn.
    - write_idx (bool): Write the index of each prediction next to its box.
    - keypoints_correction (bool): Currently unused; kept so existing callers
      keep working.
    - new_size (tuple): Size the prediction coordinates refer to (presumably
      the model input size -- confirm). It is fitted to the image while
      keeping the aspect ratio and used as the source size when rescaling.
    - only_show (str | None): Class name to restrict box drawing to; None or
      'all' disables the filter.
    - axis (bool): Keep the matplotlib axes visible when displaying.
    - return_image (bool): Return the annotated image instead of showing it.
    - resize (bool): Rescale box and OCR-text coordinates to the image size.

    Returns:
    - The annotated image copy when `return_image` is true, otherwise None
      (the image is displayed with matplotlib).
    """
    has_prediction = prediction is not None

    # Without a prediction the image is converted from a tensor to an array.
    if not has_prediction:
        image = image.squeeze(0).permute(1, 2, 0).cpu().numpy()

    image_copy = image.copy()
    scale = max(image.shape[0], image.shape[1]) / 1000
    # int(2 * scale) is 0 for images smaller than 500 px and cv2.line rejects
    # a thickness of 0, so never go below 1.
    line_thickness = max(1, int(2 * scale))

    original_size = (image.shape[0], image.shape[1])
    # Calculate scale to fit the new size while maintaining aspect ratio
    scale_ = min(new_size[0] / original_size[0], new_size[1] / original_size[1])
    new_scaled_size = (int(original_size[0] * scale_), int(original_size[1] * scale_))

    # Both sizes are handed to the resize helpers in (shape[1], shape[0]) order.
    rescaled_dims = (new_scaled_size[1], new_scaled_size[0])
    image_dims = (image_copy.shape[1], image_copy.shape[0])

    def rescale_box(box):
        """Map one box from the rescaled size to the image size."""
        return resize_boxes(np.array([box]), rescaled_dims, image_dims)[0]

    def box_center(box):
        """Return the integer center point of a box."""
        return (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))

    # Class names in ID order, so that .index(name) yields the label value.
    class_names = list(class_dict.values()) if has_prediction else []

    def flow_label_ids():
        """Return the label values of the classes that carry keypoints."""
        return tuple(class_names.index(name)
                     for name in ('sequenceFlow', 'messageFlow', 'dataAssociation'))

    # Draw boxes and their labels
    if has_prediction:
        only_label = None
        if only_show is not None and only_show != 'all':
            only_label = class_names.index(only_show)
        lane_label = class_names.index('lane') if draw_boxes else None

        for i, box in enumerate(prediction['boxes']):
            if only_label is not None and prediction['labels'][i] != only_label:
                continue
            score = prediction['scores'][i]
            if score < score_threshold:
                continue

            x1, y1, x2, y2 = rescale_box(box) if resize else box

            if draw_boxes:
                # Don't show the lanes.
                # NOTE(review): this also skips the score/index/class labels
                # of lanes, but only when draw_boxes is enabled.
                if prediction['labels'][i] == lane_label:
                    continue
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)),
                              (0, 0, 0), line_thickness)
            if write_score:
                cv2.putText(image_copy, str(round(score, 2)),
                            (int(x1), int(y1) + int(15 * scale)),
                            cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (100, 100, 255), 2)
            if write_idx:
                cv2.putText(image_copy, str(i),
                            (int(x1) + int(15 * scale), int(y1) + int(15 * scale)),
                            cv2.FONT_HERSHEY_SIMPLEX, 2 * scale, (0, 0, 0), 2)
            if write_class and 'labels' in prediction:
                class_id = prediction['labels'][i]
                cv2.putText(image_copy, class_dict[class_id],
                            (int(x1), int(y1) - int(2 * scale)),
                            cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (255, 100, 100), 2)

    # Draw keypoints if available
    if has_prediction and draw_keypoints and 'keypoints' in prediction:
        flow_labels = flow_label_ids()
        for i, kp in enumerate(prediction['keypoints']):
            # Only flow-like objects carry meaningful keypoints.
            if prediction['labels'][i] not in flow_labels:
                continue
            if prediction['scores'][i] < score_threshold:
                continue
            for j in range(kp.shape[0]):
                # NOTE(review): keypoints are rescaled even when `resize` is
                # False, unlike boxes -- confirm this is intended.
                x, y, _ = resize_keypoints(np.array([kp[j]]), rescaled_dims, image_dims)[0]
                # The first keypoint gets a different color from the others.
                color = (0, 0, 255) if j == 0 else (255, 0, 0)
                cv2.circle(image_copy, (int(x), int(y)), int(5 * scale), color, -1)

    # Draw text predictions if available
    if (draw_text or write_text) and text_predictions is not None:
        text_boxes, texts = text_predictions[0], text_predictions[1]
        for i, text_box in enumerate(text_boxes):
            x1, y1, x2, y2 = text_box
            text = texts[i]
            if resize:
                x1, y1, x2, y2 = rescale_box([float(x1), float(y1), float(x2), float(y2)])
            if draw_text:
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)),
                              (0, 255, 0), line_thickness)
            if write_text:
                cv2.putText(image_copy, text,
                            (int(x1 + int(2 * scale)), int((y1 + y2) / 2)),
                            cv2.FONT_HERSHEY_SIMPLEX, scale / 2, (0, 0, 0), 2)

    # Check if the two keypoints detected for a flow object are the same
    if draw_twins and has_prediction:
        flow_labels = flow_label_ids()
        circle_color = (0, 255, 0)  # Green color for the circle
        circle_radius = int(10 * scale)  # Circle radius scaled by image scale
        for idx, (key1, key2) in enumerate(prediction['keypoints']):
            if prediction['labels'][idx] not in flow_labels:
                continue
            # Calculate the Euclidean distance between the two keypoints
            distance = np.linalg.norm(key1[:2] - key2[:2])
            if distance < 10:
                x_new, y_new, x, y = find_other_keypoint(idx, prediction)
                cv2.circle(image_copy, (int(x), int(y)), circle_radius, circle_color, -1)
                cv2.circle(image_copy, (int(x_new), int(y_new)), circle_radius, (0, 0, 0), -1)

    # Draw links between objects
    if draw_links and has_prediction:
        for i, (start_idx, end_idx) in enumerate(prediction['links']):
            if start_idx is None or end_idx is None:
                continue
            # NOTE(review): link boxes are rescaled even when `resize` is
            # False, unlike the boxes drawn above -- confirm this is intended.
            start_center = box_center(rescale_box(prediction['boxes'][start_idx]))
            end_center = box_center(rescale_box(prediction['boxes'][end_idx]))
            current_center = box_center(rescale_box(prediction['boxes'][i]))
            # One segment from the start object to this object, one from this
            # object to the end object, in different colors.
            cv2.line(image_copy, start_center, current_center, (0, 0, 255), line_thickness)
            cv2.line(image_copy, current_center, end_center, (255, 0, 0), line_thickness)

    # Draw the OCR boxes grouped per task
    if draw_grouped_text and has_prediction and text_predictions is not None:
        task_label = class_names.index('task')
        task_boxes = [box for i, box in enumerate(prediction['boxes'])
                      if prediction['labels'][i] == task_label]
        _, sentence_bounding_boxes, _, info_boxes = group_texts(
            task_boxes, text_predictions[0], text_predictions[1], percentage_thresh=1)
        for grouped_boxes in (info_boxes, sentence_bounding_boxes):
            for x1, y1, x2, y2 in grouped_boxes:
                x1, y1, x2, y2 = rescale_box([float(x1), float(y1), float(x2), float(y2)])
                cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)),
                              (0, 255, 0), line_thickness)

    if return_image:
        return image_copy

    # Display the image
    plt.figure(figsize=(12, 12))
    plt.imshow(image_copy)
    if not axis:
        plt.axis('off')
    plt.show()