MnLgt committed
Commit 95eae85 · 1 Parent(s): 0b24893

updated yolo model

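This commit swaps the old yolo.BodyMask pipeline and the yolo-human-parse-epoch-125.pt checkpoint for the new hp package and yolo-human-parse-v2.pt. A minimal sketch of the new inference path as app.py now wires it together (the sample image path and output filename are only illustrations, not part of the commit):

    from ultralytics import YOLO

    from hp.utils import load_resize_image
    from hp.yolo_results import YOLOResults

    # Load the v2 segmentation checkpoint added in this commit
    model = YOLO("yolo-human-parse-v2.pt", task="segment")

    # Resize/pad the input to 1024 px and run segmentation with retina masks
    image = load_resize_image("sample_images/image_two.jpg", 1024)
    result = model(image, imgsz=max(image.size), retina_masks=True)

    # Wrap the raw ultralytics output and render the per-detection overlay grid
    annotated = YOLOResults(image, result).visualize(return_image=True)
    annotated.save("segmentation.png")  # illustrative output path
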
.gitignore CHANGED
@@ -1,4 +1,4 @@
-
+*/**.DS_Store
 gradio_cached_examples/
 checkpoint-*
 */example.ipynb
app.py CHANGED
@@ -1,99 +1,37 @@
 import gradio as gr
 import os
-from ultralytics import YOLO
-from yolo.BodyMask import BodyMask
+from hp.yolo_results import YOLOResults
 import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib import patches
-from skimage.transform import resize
 from PIL import Image
 import io
+from functools import lru_cache
+import logging
+from ultralytics import YOLO
+from hp.utils import load_resize_image
 
-model_id = os.path.abspath("yolo-human-parse-epoch-125.pt")
-
-
-def display_image_with_masks(image, results, cols=4):
-    # Convert PIL Image to numpy array
-    image_np = np.array(image)
-
-    # Check image dimensions
-    if image_np.ndim != 3 or image_np.shape[2] != 3:
-        raise ValueError("Image must be a 3-dimensional array with 3 color channels")
-
-    # Number of masks
-    n = len(results)
-    rows = (n + cols - 1) // cols  # Calculate required number of rows
-
-    # Setting up the plot
-    fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
-    axs = np.array(axs).reshape(-1)  # Flatten axs array for easy indexing
-
-    for i, result in enumerate(results):
-        mask = result["mask"]
-        label = result["label"]
-        score = float(result["score"])
-
-        # Convert PIL mask to numpy array and resize if necessary
-        mask_np = np.array(mask)
-        if mask_np.shape != image_np.shape[:2]:
-            mask_np = resize(
-                mask_np, image_np.shape[:2], mode="constant", anti_aliasing=False
-            )
-            mask_np = (mask_np > 0.5).astype(
-                np.uint8
-            )  # Threshold back to binary after resize
-
-        # Create an overlay where mask is True
-        overlay = np.zeros_like(image_np)
-        overlay[mask_np > 0] = [0, 0, 255]  # Applying blue color on the mask area
-
-        # Combine the image and the overlay
-        combined = image_np.copy()
-        indices = np.where(mask_np > 0)
-        combined[indices] = combined[indices] * 0.5 + overlay[indices] * 0.5
-
-        # Show the combined image
-        ax = axs[i]
-        ax.imshow(combined)
-        ax.axis("off")
-        ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
-        rect = patches.Rectangle(
-            (0, 0),
-            image_np.shape[1],
-            image_np.shape[0],
-            linewidth=1,
-            edgecolor="r",
-            facecolor="none",
-        )
-        ax.add_patch(rect)
-
-    # Hide unused subplots if the total number of masks is not a multiple of cols
-    for idx in range(i + 1, rows * cols):
-        axs[idx].axis("off")
-
-    plt.tight_layout()
-
-    # Save the plot to a bytes buffer
-    buf = io.BytesIO()
-    plt.savefig(buf, format="png")
-    buf.seek(0)
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
-    # Clear the current figure
-    plt.close(fig)
+model_id = os.path.abspath("yolo-human-parse-v2.pt")
 
-    return buf
 
+@lru_cache
+def get_model(model_id=model_id):
+    return YOLO(model_id, task="segment")
 
-def perform_segmentation(input_image):
-    bm = BodyMask(input_image, model_id=model_id, resize_to=640)
-    if bm.body_mask is None:
-        return input_image  # Return the original image if no mask is found
-    results = bm.results
-    buf = display_image_with_masks(input_image, results)
 
-    # Convert BytesIO to PIL Image
-    img = Image.open(buf)
-    return img
+def perform_segmentation(image):
+    model = get_model()
+    image = load_resize_image(image, 1024)
+    imgsz = max(image.size)
+    result = model(image, imgsz=imgsz, retina_masks=True)
+    if not bool(result):
+        logger.info("No Masks or Boxes Found")
+        return image
+    result = YOLOResults(image, result)
+    return result.visualize(return_image=True)
 
 
 # Get example images
hp/utils.py ADDED
@@ -0,0 +1,81 @@
+import os
+import random
+from typing import List, Union
+
+import numpy as np
+from PIL import Image, ImageOps
+from ultralytics import YOLO
+
+from hp.visualizer import visualizer
+
+
+def resize_image_pil(image_pil, max_size=1024):
+    # Ensure image is in RGB
+    if image_pil.mode != "RGB":
+        image_pil = image_pil.convert("RGB")
+
+    # Calculate new dimensions preserving aspect ratio
+    width, height = image_pil.size
+    scale = min(max_size / width, max_size / height)
+    new_width = int(width * scale)
+    new_height = int(height * scale)
+    image_pil = image_pil.resize((new_width, new_height), Image.LANCZOS)
+
+    # Calculate padding needed to reach 1024x1024
+    pad_width = (max_size - new_width) // 2
+    pad_height = (max_size - new_height) // 2
+
+    # Apply padding symmetrically
+    image_pil = ImageOps.expand(
+        image_pil,
+        border=(
+            pad_width,
+            pad_height,
+            max_size - new_width - pad_width,
+            max_size - new_height - pad_height,
+        ),
+        fill=(0, 0, 0),
+    )
+
+    return image_pil
+
+
+def load_resize_image(image_path: str | Image.Image, size: int) -> Image.Image:
+    if isinstance(image_path, str):
+        image_pil = Image.open(image_path).convert("RGB")
+    else:
+        image_pil = image_path.convert("RGB")
+
+    image_pil = resize_image_pil(image_pil, size)
+    return image_pil
+
+
+def unload_mask(mask):
+    mask = mask.cpu().numpy().squeeze()
+    mask = mask.astype(np.uint8) * 255
+    return Image.fromarray(mask)
+
+
+def unload_masks(masks):
+    return [unload_mask(mask) for mask in masks]
+
+
+def unload_box(box):
+    return box.cpu().numpy().tolist()
+
+
+def unload_boxes(boxes):
+    return [unload_box(box) for box in boxes]
+
+
+def format_scores(scores):
+    return scores.squeeze().cpu().numpy().tolist()
+
+
+def format_results(labels, scores, boxes, masks):
+    results_dict = []
+    for row in zip(labels, scores, boxes, masks):
+        label, score, box, mask = row
+        results_row = dict(label=label, score=score, mask=mask, box=box)
+        results_dict.append(results_row)
+    return results_dict
hp/visualizer.py ADDED
@@ -0,0 +1,102 @@
+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+import io
+
+
+def visualizer(
+    image,
+    results,
+    box_label="box",
+    mask_label="mask",
+    prompt_label="prompt",
+    score_label="score",
+    cols=4,
+    return_image=False,
+    **kwargs,
+):
+    # Convert PIL Image to numpy array
+    image_np = np.array(image)
+
+    # Check image dimensions
+    if image_np.ndim != 3:
+        raise ValueError("Image must be a 3-dimensional array")
+
+    # Number of results
+    n = len(results)
+    rows = (n + cols - 1) // cols  # Calculate required number of rows
+
+    # Setting up the plot
+    fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
+    if n == 1:
+        axs = np.array([[axs]])
+    elif rows == 1:
+        axs = np.array([axs])
+    else:
+        axs = axs.reshape(rows, cols)
+
+    for i, result in enumerate(results):
+        label = result[prompt_label]
+        score = float(result[score_label])
+
+        row = i // cols
+        col = i % cols
+
+        # Create a copy of the original image
+        combined = image_np.copy()
+
+        # Draw mask if present
+        if mask_label in result:
+            mask = result[mask_label]
+            # Convert PIL mask to numpy array
+            mask_np = np.array(mask)
+
+            # Check mask dimensions
+            if mask_np.ndim != 2:
+                raise ValueError("Mask must be a 2-dimensional array")
+
+            # Create an overlay where mask is True
+            overlay = np.zeros_like(image_np)
+            overlay[mask_np > 0] = [0, 0, 255]  # Applying blue color on the mask area
+
+            # Combine the image and the overlay
+            indices = np.where(mask_np > 0)
+            combined[indices] = combined[indices] * 0.5 + overlay[indices] * 0.5
+
+        # Show the combined image
+        ax = axs[row, col]
+        ax.imshow(combined)
+        ax.axis("off")
+        ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
+
+        # Draw bounding box if present
+        if box_label in result:
+            bbox = result[box_label]
+            x1, y1, x2, y2 = bbox
+            rect = patches.Rectangle(
+                (x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor="r", facecolor="none"
+            )
+            ax.add_patch(rect)
+
+    # Hide unused subplots if the total number of results is not a multiple of cols
+    for idx in range(i + 1, rows * cols):
+        row = idx // cols
+        col = idx % cols
+        axs[row, col].axis("off")
+
+    plt.tight_layout()
+
+    if return_image:
+        # Save the plot to a bytes buffer
+        buf = io.BytesIO()
+        plt.savefig(buf, format="png")
+        buf.seek(0)
+
+        # Clear the current figure
+        plt.close(fig)
+
+        # Return the image as a PIL Image object
+        return Image.open(buf)
+    else:
+        plt.show()
hp/yolo_results.py ADDED
@@ -0,0 +1,44 @@
+from typing import List, Union
+from PIL import Image
+from ultralytics import YOLO
+
+from hp.visualizer import visualizer
+from .utils import *
+
+
+class YOLOResults:
+    def __init__(self, image: Union[Image.Image | str], result: List):
+        self.image = image
+        self.masks = None
+        self.boxes = None
+        self.scores = None
+        self.labels = None
+        self.labels_dict = None
+        self.result = self.unload(result[0])
+        self.formatted_results = format_results(
+            self.labels,
+            self.scores,
+            self.boxes,
+            self.masks,
+        )
+
+    def unload(self, result):
+        assert (
+            bool(result) and hasattr(result, "masks") and hasattr(result, "boxes")
+        ), "No Masks or Boxes Found"
+        self.masks = unload_masks(result.masks.data)
+        self.boxes = unload_boxes(result.boxes.xyxy)
+        self.scores = format_scores(result.boxes.conf)
+        self.labels = list(result.names.values())
+        self.labels_dict = result.names
+        det_ids = result.boxes.cls
+        det_ids = [int(l.item()) for l in det_ids]
+        self.labels = [self.labels_dict[i] for i in det_ids]
+
+    def visualize(self, return_image=False):
+        return visualizer(
+            self.image,
+            self.formatted_results,
+            prompt_label="label",
+            return_image=return_image,
+        )
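
For reference, a hedged sketch of consuming the new YOLOResults wrapper directly: formatted_results is a plain list of dicts (label, score, xyxy box, PIL mask), which is what hp.visualizer.visualizer iterates over. The variable names and output path below are illustrative only:

    # `image` and `result` built as in app.py's perform_segmentation above
    res = YOLOResults(image, result)

    for det in res.formatted_results:
        # Each entry carries the class label, confidence, xyxy box, and a PIL mask
        print(det["label"], round(det["score"], 3), det["box"])
        det["mask"].save(f"{det['label']}_mask.png")  # illustrative output path
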
sample_images/image_two.jpg CHANGED
yolo-human-parse-v2.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1a5e777a3e980c26d70067246e7eb11749294af43fa355ba90af4c3076d849f
+size 13498800
yolo/BodyMask.py DELETED
@@ -1,210 +0,0 @@
-import os
-from functools import lru_cache
-from typing import List
-
-import cv2
-import numpy as np
-from diffusers.utils import load_image
-from PIL import Image, ImageChops, ImageFilter
-from ultralytics import YOLO
-from .utils import *
-
-
-def dilate_mask(mask, dilate_factor=6, blur_radius=2, erosion_factor=2):
-    if not mask:
-        return None
-    # Convert PIL image to NumPy array if necessary
-    if isinstance(mask, Image.Image):
-        mask = np.array(mask)
-
-    # Ensure mask is in uint8 format
-    mask = mask.astype(np.uint8)
-
-    # Apply dilation
-    kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
-    dilated_mask = cv2.dilate(mask, kernel, iterations=1)
-
-    # Apply erosion for refinement
-    kernel = np.ones((erosion_factor, erosion_factor), np.uint8)
-    eroded_mask = cv2.erode(dilated_mask, kernel, iterations=1)
-
-    # Apply Gaussian blur to smooth the edges
-    blurred_mask = cv2.GaussianBlur(
-        eroded_mask, (2 * blur_radius + 1, 2 * blur_radius + 1), 0
-    )
-
-    # Convert back to PIL image
-    smoothed_mask = Image.fromarray(blurred_mask).convert("L")
-
-    # Optionally, apply an additional blur for extra smoothness using PIL
-    smoothed_mask = smoothed_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
-
-    return smoothed_mask
-
-
-@lru_cache(maxsize=1)
-def get_model(model_id):
-    model = YOLO(model=model_id)
-    return model
-
-
-def combine_masks(masks: List[dict], labels: List[str], is_label=True) -> Image.Image:
-    """
-    Combine masks with the specified labels into a single mask, optimized for speed and non-overlapping of excluded masks.
-
-    Parameters:
-    - masks (List[dict]): A list of dictionaries, each containing the mask under a 'mask' key and its label under a 'label' key.
-    - labels (List[str]): A list of labels to include in the combination.
-
-    Returns:
-    - Image.Image: The combined mask as a PIL Image object, or None if no masks are combined.
-    """
-    labels_set = set(labels)  # Convert labels list to a set for O(1) lookups
-
-    # Filter out any masks that do not have a label key
-    masks = [mask for mask in masks if "label" in mask]
-
-    # Filter and convert mask images based on the specified labels
-    mask_images = [
-        mask["mask"].convert("L")
-        for mask in masks
-        if (mask["label"] in labels_set) == is_label
-    ]
-
-    # Ensure there is at least one mask to combine
-    if not mask_images:
-        return None  # Or raise an appropriate error, e.g., ValueError("No masks found for the specified labels.")
-
-    # Initialize the combined mask with the first mask
-    combined_mask = mask_images[0]
-
-    # Combine the remaining masks with the existing combined_mask using a bitwise OR operation to ensure non-overlap
-    for mask in mask_images[1:]:
-        combined_mask = ImageChops.lighter(combined_mask, mask)
-
-    return combined_mask
-
-
-body_labels = ["hair", "face", "arm", "hand", "leg", "foot", "outfit"]
-
-
-class BodyMask:
-    def __init__(
-        self,
-        image_path,
-        model_id,
-        labels=body_labels,
-        overlay="mask",
-        widen_box=0,
-        elongate_box=0,
-        resize_to=640,
-        dilate_factor=0,
-        is_label=False,
-        resize_to_nearest_eight=False,
-        verbose=True,
-        remove_overlap=True,
-    ):
-        self.image_path = image_path
-        self.image = self.get_image(
-            resize_to=resize_to, resize_to_nearest_eight=resize_to_nearest_eight
-        )
-        self.labels = labels
-        self.is_label = is_label
-        self.model_id = model_id
-        self.model = get_model(self.model_id)
-        self.model_labels = self.model.names
-        self.verbose = verbose
-        self.results = self.get_results()
-        self.dilate_factor = dilate_factor
-        self.body_mask = self.get_body_mask()
-        self.box = self.get_bounding_box()
-        self.body_box = self.get_body_box(
-            remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
-        )
-        self.overlay = self.create_overlay(overlay)
-
-    def get_image(self, resize_to, resize_to_nearest_eight):
-        image = load_image(self.image_path)
-        if resize_to:
-            image = resize_preserve_aspect_ratio(image, resize_to)
-        if resize_to_nearest_eight:
-            image = resize_image_to_nearest_eight(image)
-        return image
-
-    def get_results(self):
-        imgsz = max(self.image.size)
-        results = self.model(
-            self.image, retina_masks=True, imgsz=imgsz, verbose=self.verbose
-        )[0]
-        masks, boxes, scores, phrases = unload(results, self.model_labels)
-        results = format_results(
-            masks, boxes, scores, phrases, self.model_labels, person_masks_only=False
-        )
-        masks_to_filter = ["hair"]
-        results = filter_highest_score(results, ["hair", "face", "phone"])
-        return results
-
-    def get_body_mask(self):
-        body_mask = combine_masks(self.results, self.labels, self.is_label)
-        if body_mask is not None:
-            return dilate_mask(body_mask, self.dilate_factor)
-        return None
-
-    def get_bounding_box(self):
-        if self.body_mask is None:
-            return None
-        return get_bounding_box(self.body_mask)
-
-    def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
-        if self.body_mask is None or self.box is None:
-            return None
-        body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
-        if remove_overlap and body_box is not None:
-            body_box = self.remove_overlap(body_box)
-        return body_box
-
-    def create_overlay(self, overlay_type):
-        if self.body_box is not None and overlay_type == "box":
-            return overlay_mask(self.image, self.body_box, opacity=0.9, color="red")
-        elif self.body_mask is not None:
-            return overlay_mask(self.image, self.body_mask, opacity=0.9, color="red")
-        return self.image
-
-    def remove_overlap(self, body_box):
-        if body_box is None:
-            return None
-        box_array = np.array(body_box)
-        mask = self.combine_masks(mask_labels=self.labels, is_label=True)
-        if mask is None:
-            return body_box
-        mask_array = np.array(mask)
-        box_array[mask_array == 255] = 0
-        return Image.fromarray(box_array)
-
-    def combine_masks(self, mask_labels: List, no_labels=None, is_label=True):
-        if not is_label:
-            mask_labels = [
-                phrase for phrase in self.phrases if phrase not in mask_labels
-            ]
-        masks = [
-            row.get("mask") for row in self.results if row.get("label") in mask_labels
-        ]
-        if len(masks) == 0:
-            return None
-        combined_mask = masks[0]
-        for mask in masks[1:]:
-            combined_mask = ImageChops.lighter(combined_mask, mask)
-        return combined_mask
-
-    def display_results(self):
-        if not self.results:
-            print("No results to display.")
-            return
-        cols = min(len(self.results), 4)
-        display_image_with_masks(self.image, self.results, cols=cols)
-
-    def get_mask(self, mask_label):
-        if mask_label not in self.phrases:
-            print(f"Mask label '{mask_label}' not found in results.")
-            return None
-        return [f for f in self.results if f.get("label") == mask_label]

yolo/utils.py DELETED
@@ -1,298 +0,0 @@
-import matplotlib.patches as patches
-import matplotlib.pyplot as plt
-import numpy as np
-from PIL import Image, ImageDraw
-
-
-def unload_mask(mask):
-    mask = mask.cpu().numpy().squeeze()
-    mask = mask.astype(np.uint8) * 255
-    return Image.fromarray(mask)
-
-
-def unload_box(box):
-    return box.cpu().numpy().tolist()
-
-
-def masks_overlap(mask1, mask2):
-    return np.any(np.logical_and(mask1, mask2))
-
-
-def remove_non_person_masks(person_mask, formatted_results):
-    return [
-        f
-        for f in formatted_results
-        if f.get("label") == "person" or masks_overlap(person_mask, f.get("mask"))
-    ]
-
-
-def format_masks(masks):
-    return [unload_mask(mask) for mask in masks]
-
-
-def format_boxes(boxes):
-    return [unload_box(box) for box in boxes]
-
-
-def format_scores(scores):
-    return scores.cpu().numpy().tolist()
-
-
-def unload(result, labels_dict):
-    masks = format_masks(result.masks.data)
-    boxes = format_boxes(result.boxes.xyxy)
-    scores = format_scores(result.boxes.conf)
-    labels = result.boxes.cls
-    labels = [int(label.item()) for label in labels]
-    phrases = [labels_dict[label] for label in labels]
-    return masks, boxes, scores, phrases
-
-
-def format_results(masks, boxes, scores, labels, labels_dict, person_masks_only=True):
-    if isinstance(list(labels_dict.keys())[0], int):
-        labels_dict = {v: k for k, v in labels_dict.items()}
-
-    # check that the person mask is present
-    if person_masks_only:
-        assert "person" in labels, "Person mask not present in results"
-    results_dict = []
-    for row in zip(labels, scores, boxes, masks):
-        label, score, box, mask = row
-        label_id = labels_dict[label]
-        results_row = dict(
-            label=label, score=score, mask=mask, box=box, label_id=label_id
-        )
-        results_dict.append(results_row)
-    results_dict = sorted(results_dict, key=lambda x: x["label"])
-    if person_masks_only:
-        # Get the person mask
-        person_mask = [f for f in results_dict if f.get("label") == "person"][0]["mask"]
-        assert person_mask is not None, "Person mask not found in results"
-
-        # Remove any results that do no overlap with the person
-        results_dict = remove_non_person_masks(person_mask, results_dict)
-    return results_dict
-
-
-def filter_highest_score(results, labels):
-    """
-    Filter results to remove entries with lower scores for specified labels.
-
-    Args:
-        results (list): List of dictionaries containing 'label', 'score', and other keys.
-        labels (list): List of labels to filter.
-
-    Returns:
-        list: Filtered results with only the highest score for each specified label.
-    """
-    # Dictionary to keep track of the highest score entry for each label
-    label_highest = {}
-
-    # First pass: identify the highest score for each label
-    for result in results:
-        label = result["label"]
-        if label in labels:
-            if (
-                label not in label_highest
-                or result["score"] > label_highest[label]["score"]
-            ):
-                label_highest[label] = result
-
-    # Second pass: construct the filtered list while preserving the order
-    filtered_results = []
-    seen_labels = set()
-
-    for result in results:
-        label = result["label"]
-        if label in labels:
-            if label in seen_labels:
-                continue
-            if result == label_highest[label]:
-                filtered_results.append(result)
-                seen_labels.add(label)
-        else:
-            filtered_results.append(result)
-
-    return filtered_results
-
-
-def display_image_with_masks(image, results, cols=4, return_images=False):
-    # Convert PIL Image to numpy array
-    image_np = np.array(image)
-
-    # Check image dimensions
-    if image_np.ndim != 3 or image_np.shape[2] != 3:
-        raise ValueError("Image must be a 3-dimensional array with 3 color channels")
-
-    # Number of masks
-    n = len(results)
-    rows = (n + cols - 1) // cols  # Calculate required number of rows
-
-    # Setting up the plot
-    fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
-    axs = np.array(axs).reshape(-1)  # Flatten axs array for easy indexing
-    for i, result in enumerate(results):
-        mask = result["mask"]
-        label = result["label"]
-        score = float(result["score"])
-
-        # Convert PIL mask to numpy array and resize if necessary
-        mask_np = np.array(mask)
-        if mask_np.shape != image_np.shape[:2]:
-            mask_np = resize(
-                mask_np, image_np.shape[:2], mode="constant", anti_aliasing=False
-            )
-            mask_np = (mask_np > 0.5).astype(
-                np.uint8
-            )  # Threshold back to binary after resize
-
-        # Create an overlay where mask is True
-        overlay = np.zeros_like(image_np)
-        overlay[mask_np > 0] = [0, 0, 255]  # Applying blue color on the mask area
-
-        # Combine the image and the overlay
-        combined = image_np.copy()
-        indices = np.where(mask_np > 0)
-        combined[indices] = combined[indices] * 0.5 + overlay[indices] * 0.5
-
-        # Show the combined image
-        ax = axs[i]
-        ax.imshow(combined)
-        ax.axis("off")
-        ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
-        rect = patches.Rectangle(
-            (0, 0),
-            image_np.shape[1],
-            image_np.shape[0],
-            linewidth=1,
-            edgecolor="r",
-            facecolor="none",
-        )
-        ax.add_patch(rect)
-
-    # Hide unused subplots if the total number of masks is not a multiple of cols
-    for idx in range(i + 1, rows * cols):
-        axs[idx].axis("off")
-    plt.tight_layout()
-    plt.show()
-
-
-def get_bounding_box(mask):
-    if mask is None or not isinstance(mask, np.ndarray):
-        return None
-
-    # Check if the mask is empty
-    if mask.size == 0 or np.all(mask == 0):
-        return None
-
-    # Find the bounding box
-    rows = np.any(mask, axis=1)
-    cols = np.any(mask, axis=0)
-    if not np.any(rows) or not np.any(cols):
-        return None
-
-    rmin, rmax = np.where(rows)[0][[0, -1]]
-    cmin, cmax = np.where(cols)[0][[0, -1]]
-
-    return (int(cmin), int(rmin), int(cmax), int(rmax))
-
-
-def get_bounding_box_mask(segmentation_mask, widen=0, elongate=0):
-    # Convert the PIL segmentation mask to a NumPy array
-    mask_array = np.array(segmentation_mask)
-
-    # Find the coordinates of the non-zero pixels
-    non_zero_y, non_zero_x = np.nonzero(mask_array)
-
-    # Calculate the bounding box coordinates
-    min_x, max_x = np.min(non_zero_x), np.max(non_zero_x)
-    min_y, max_y = np.min(non_zero_y), np.max(non_zero_y)
-
-    if widen > 0:
-        min_x = max(0, min_x - widen)
-        max_x = min(mask_array.shape[1], max_x + widen)
-
-    if elongate > 0:
-        min_y = max(0, min_y - elongate)
-        max_y = min(mask_array.shape[0], max_y + elongate)
-
-    # Create a new blank image for the bounding box mask
-    bounding_box_mask = Image.new("1", segmentation_mask.size)
-
-    # Draw the filled bounding box on the blank image
-    draw = ImageDraw.Draw(bounding_box_mask)
-    draw.rectangle([(min_x, min_y), (max_x, max_y)], fill=1)
-
-    return bounding_box_mask
-
-
-colors = {
-    "blue": (136, 207, 249),
-    "red": (255, 0, 0),
-    "green": (0, 255, 0),
-    "yellow": (255, 255, 0),
-    "purple": (128, 0, 128),
-    "cyan": (0, 255, 255),
-    "magenta": (255, 0, 255),
-    "orange": (255, 165, 0),
-    "lime": (50, 205, 50),
-    "pink": (255, 192, 203),
-    "brown": (139, 69, 19),
-    "gray": (128, 128, 128),
-    "black": (0, 0, 0),
-    "white": (255, 255, 255),
-    "gold": (255, 215, 0),
-    "silver": (192, 192, 192),
-    "beige": (245, 245, 220),
-    "navy": (0, 0, 128),
-    "maroon": (128, 0, 0),
-    "olive": (128, 128, 0),
-}
-
-
-def overlay_mask(image, mask, opacity=0.5, color="blue"):
-    """
-    Takes in a PIL image and a PIL boolean image mask. Overlay the mask on the image
-    and color the mask with a low opacity blue with hex #88CFF9.
-    """
-    # Convert the boolean mask to an image with alpha channel
-    alpha = mask.convert("L").point(lambda x: 255 if x == 255 else 0, mode="1")
-
-    # Choose the color
-    r, g, b = colors[color]
-
-    color_mask = Image.new("RGBA", mask.size, (r, g, b, int(opacity * 255)))
-    mask_rgba = Image.composite(
-        color_mask, Image.new("RGBA", mask.size, (0, 0, 0, 0)), alpha
-    )
-
-    # Create a new RGBA image to overlay the mask on
-    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
-
-    # Paste the mask onto the overlay
-    overlay.paste(mask_rgba, (0, 0))
-
-    # Create a new image to return by blending the original image and the overlay
-    result = Image.alpha_composite(image.convert("RGBA"), overlay)
-
-    # Convert the result back to the original mode and return it
-    return result.convert(image.mode)
-
-
-def resize_preserve_aspect_ratio(image, max_side=512):
-    width, height = image.size
-    scale = min(max_side / width, max_side / height)
-    new_width = int(width * scale)
-    new_height = int(height * scale)
-    return image.resize((new_width, new_height))
-
-
-def round_to_nearest_eigth(value):
-    return int((value // 8 * 8))
-
-
-def resize_image_to_nearest_eight(image):
-    width, height = image.size
-    width, height = round_to_nearest_eigth(width), round_to_nearest_eigth(height)
-    image = image.resize((width, height))
-    return image