MnLgt committed

Commit d7de9f0 · 1 Parent(s): 587a422

Add application file
.gitignore ADDED
@@ -0,0 +1,6 @@
+
+ gradio_cached_examples/
+ checkpoint-*
+ */example.ipynb
+
+ *.pyc
README.md CHANGED
@@ -1,12 +1,104 @@
  ---
- title: YOLO Human Parse
- emoji:
- colorFrom: pink
- colorTo: red
- sdk: gradio
- sdk_version: 4.44.0
- app_file: app.py
- pinned: false
+ license: apache-2.0
+ tags:
+ - vision
+ - image-classification
+ widget:
+ - src: >-
+     https://huggingface.co/jordandavis/yolo-human-parse/blob/main/sample_images/image_one.jpg
+   example_title: Straight ahead
+ - src: >-
+     https://huggingface.co/jordandavis/yolo-human-parse/blob/main/sample_images/image_two.jpg
+   example_title: Looking back
+ - src: >-
+     https://huggingface.co/jordandavis/yolo-human-parse/blob/main/sample_images/image_three.jpg
+   example_title: Sweats
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # YOLO Segmentation Model for Human Body Parts and Objects
+
+ This repository contains a fine-tuned YOLO (You Only Look Once) segmentation model designed to detect and segment various human body parts and objects in images.
+
+ ## Model Overview
+
+ The model is based on the YOLO architecture and has been fine-tuned to detect and segment the following classes:
+
+ 0. Hair
+ 1. Face
+ 2. Neck
+ 3. Arm
+ 4. Hand
+ 5. Back
+ 6. Leg
+ 7. Foot
+ 8. Outfit
+ 9. Person
+ 10. Phone
+
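+ The numbers above are the model's internal class IDs. A quick way to confirm the ID-to-name mapping on your own copy of the weights (a minimal sketch; the checkpoint filename is the one shipped in this repository, so adjust the path as needed):
+
+ ```python
+ from ultralytics import YOLO
+
+ # Load the fine-tuned checkpoint from this repo
+ model = YOLO("yolo-human-parse-epoch-125.pt")
+
+ # model.names maps class ID -> label, e.g. {0: 'hair', 1: 'face', ...}
+ print(model.names)
+ ```
+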
+ ## Installation
+
+ To use this model, you'll need the Ultralytics YOLO framework installed. Please follow these steps:
+
+ 1. Clone this repository:
+ ```
+ git clone https://github.com/your-username/yolo-segmentation-human-parts.git
+ cd yolo-segmentation-human-parts
+ ```
+
+ 2. Install the required dependencies:
+ ```
+ pip install -r requirements.txt
+ ```
+
+ ## Usage
+
+ To use the model for inference, you can use the following Python script:
+
+ ```python
+ from ultralytics import YOLO
+
+ # Load the model
+ model = YOLO('path/to/your/model.pt')
+
+ # Perform inference on an image
+ results = model('path/to/your/image.jpg')
+
+ # Process the results
+ for result in results:
+     boxes = result.boxes  # Bounding boxes
+     masks = result.masks  # Segmentation masks
+     # Further processing...
+ ```
+
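+ Each element of `results` carries the detections for one image. As one concrete example of the "further processing" step (a sketch, not part of the app code; it assumes at least one mask was detected), this saves every instance mask as a binary PNG named after its class:
+
+ ```python
+ import numpy as np
+ from PIL import Image
+
+ for result in results:
+     for i, (mask, cls) in enumerate(zip(result.masks.data, result.boxes.cls)):
+         label = result.names[int(cls)]  # class ID -> name
+         binary = (mask.cpu().numpy() > 0.5).astype(np.uint8) * 255
+         Image.fromarray(binary).save(f"{label}_{i}.png")
+ ```
+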
+ ## Training
+
+ If you want to further fine-tune the model on your own dataset, please follow these steps:
+
+ 1. Prepare your dataset in the YOLO format.
+ 2. Modify the `data.yaml` file to reflect your dataset structure and classes.
+ 3. Run the training script:
+ ```
+ python train.py --img 640 --batch 16 --epochs 100 --data data.yaml --weights yolov5s-seg.pt
+ ```
+
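+ If you prefer the Ultralytics Python API to a CLI script, the equivalent call looks roughly like this (a sketch mirroring the arguments above; `yolov8n-seg.pt` stands in for whichever pretrained segmentation checkpoint you start from):
+
+ ```python
+ from ultralytics import YOLO
+
+ # Fine-tune a pretrained segmentation checkpoint on your dataset
+ model = YOLO("yolov8n-seg.pt")
+ model.train(data="data.yaml", imgsz=640, batch=16, epochs=100)
+ ```
+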
+ ## Evaluation
+
+ To evaluate the model's performance on your test set, use:
+
+ ```
+ python val.py --weights path/to/your/model.pt --data data.yaml --task segment
+ ```
+
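+ The same evaluation is also available through the Python API (again a sketch, assuming your own weights and `data.yaml` paths):
+
+ ```python
+ from ultralytics import YOLO
+
+ model = YOLO("path/to/your/model.pt")
+ metrics = model.val(data="data.yaml")
+ print(metrics.seg.map)  # mask mAP50-95
+ ```
+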
+ ## Contributing
+
+ Contributions to improve the model or extend its capabilities are welcome. Please submit a pull request or open an issue to discuss proposed changes.
+
+ ## License
+
+ This project is licensed under the Apache License 2.0, as declared in the metadata above - see the [LICENSE](LICENSE) file for details.
+
+ ## Acknowledgments
+
+ - Thanks to the YOLO team for the original implementation.
+ - Gratitude to all contributors who helped in fine-tuning and improving this model.
app.py ADDED
@@ -0,0 +1,131 @@
+ import gradio as gr
+ import os
+ from ultralytics import YOLO
+ from yolo.BodyMask import BodyMask
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from matplotlib import patches
+ from skimage.transform import resize
+ from PIL import Image
+ import io
+
+ model_id = os.path.abspath("yolo-human-parse-epoch-125.pt")
+
+
+ def display_image_with_masks(image, results, cols=4):
+     # Convert PIL Image to numpy array
+     image_np = np.array(image)
+
+     # Check image dimensions
+     if image_np.ndim != 3 or image_np.shape[2] != 3:
+         raise ValueError("Image must be a 3-dimensional array with 3 color channels")
+
+     # Number of masks
+     n = len(results)
+     rows = (n + cols - 1) // cols  # Calculate required number of rows
+
+     # Setting up the plot
+     fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
+     axs = np.array(axs).reshape(-1)  # Flatten axs array for easy indexing
+
+     for i, result in enumerate(results):
+         mask = result["mask"]
+         label = result["label"]
+         score = float(result["score"])
+
+         # Convert PIL mask to numpy array and resize if necessary
+         mask_np = np.array(mask)
+         if mask_np.shape != image_np.shape[:2]:
+             mask_np = resize(
+                 mask_np, image_np.shape[:2], mode="constant", anti_aliasing=False
+             )
+             mask_np = (mask_np > 0.5).astype(
+                 np.uint8
+             )  # Threshold back to binary after resize
+
+         # Create an overlay where mask is True
+         overlay = np.zeros_like(image_np)
+         overlay[mask_np > 0] = [0, 0, 255]  # Applying blue color on the mask area
+
+         # Combine the image and the overlay
+         combined = image_np.copy()
+         indices = np.where(mask_np > 0)
+         combined[indices] = combined[indices] * 0.5 + overlay[indices] * 0.5
+
+         # Show the combined image
+         ax = axs[i]
+         ax.imshow(combined)
+         ax.axis("off")
+         ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
+         rect = patches.Rectangle(
+             (0, 0),
+             image_np.shape[1],
+             image_np.shape[0],
+             linewidth=1,
+             edgecolor="r",
+             facecolor="none",
+         )
+         ax.add_patch(rect)
+
+     # Hide unused subplots if the total number of masks is not a multiple of cols
+     for idx in range(i + 1, rows * cols):
+         axs[idx].axis("off")
+
+     plt.tight_layout()
+
+     # Save the plot to a bytes buffer
+     buf = io.BytesIO()
+     plt.savefig(buf, format="png")
+     buf.seek(0)
+
+     # Clear the current figure
+     plt.close(fig)
+
+     return buf
+
+
+ def perform_segmentation(input_image):
+     bm = BodyMask(input_image, model_id=model_id, resize_to=640)
+     results = bm.results
+     buf = display_image_with_masks(input_image, results)
+
+     # Convert BytesIO to PIL Image
+     img = Image.open(buf)
+     return img
+
+
+ # Get example images
+ example_images = [
+     os.path.join("sample_images", f)
+     for f in os.listdir("sample_images")
+     if f.endswith((".png", ".jpg", ".jpeg"))
+ ]
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# YOLO Segmentation Demo with BodyMask")
+     gr.Markdown(
+         "Upload an image or select an example to see the YOLO segmentation results."
+     )
+
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(type="pil", label="Input Image", height=512)
+             segment_button = gr.Button("Perform Segmentation")
+
+         output_image = gr.Image(label="Segmentation Result")
+
+     gr.Examples(
+         examples=example_images,
+         inputs=input_image,
+         outputs=output_image,
+         fn=perform_segmentation,
+         cache_examples=True,
+     )
+
+     segment_button.click(
+         fn=perform_segmentation,
+         inputs=input_image,
+         outputs=output_image,
+     )
+
+ demo.launch()
config.json ADDED
@@ -0,0 +1,4 @@
+ {
+     "input_size": 640,
+     "task": "segment"
+ }
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ diffusers==0.30.3
+ gradio==4.44.0
+ matplotlib==3.8.4
+ numpy==1.26.4
+ Pillow==10.4.0
+ scikit-image  # provides the skimage import; the PyPI package "skimage==0.0" is a broken placeholder
+ ultralytics==8.2.97
sample_images/image_five.jpg ADDED
sample_images/image_four.jpg ADDED
sample_images/image_one.jpg ADDED
sample_images/image_six.jpg ADDED
sample_images/image_three.jpg ADDED
sample_images/image_two.jpg ADDED
yolo-human-parse-epoch-125.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a78215ccb99e5185249f41d03c17abef91f41dae3be2dd66f9633303856ed702
+ size 431332491
yolo/BodyMask.py ADDED
@@ -0,0 +1,248 @@
+ import os
+ from functools import lru_cache
+ from typing import List
+
+ import cv2
+ import numpy as np
+ from diffusers.utils import load_image
+ from PIL import Image, ImageChops, ImageFilter
+ from ultralytics import YOLO
+ from .utils import *
+
+
+ def dilate_mask(mask, dilate_factor=6, blur_radius=2, erosion_factor=2):
+     if mask is None:
+         return None
+     # Convert PIL image to NumPy array if necessary
+     if isinstance(mask, Image.Image):
+         mask = np.array(mask)
+
+     # Ensure mask is in uint8 format
+     mask = mask.astype(np.uint8)
+
+     # Apply dilation
+     kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
+     dilated_mask = cv2.dilate(mask, kernel, iterations=1)
+
+     # Apply erosion for refinement
+     kernel = np.ones((erosion_factor, erosion_factor), np.uint8)
+     eroded_mask = cv2.erode(dilated_mask, kernel, iterations=1)
+
+     # Apply Gaussian blur to smooth the edges
+     blurred_mask = cv2.GaussianBlur(
+         eroded_mask, (2 * blur_radius + 1, 2 * blur_radius + 1), 0
+     )
+
+     # Convert back to PIL image
+     smoothed_mask = Image.fromarray(blurred_mask).convert("L")
+
+     # Optionally, apply an additional blur for extra smoothness using PIL
+     smoothed_mask = smoothed_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
+
+     return smoothed_mask
+
+
+ @lru_cache(maxsize=1)
+ def get_model(model_id):
+     model = YOLO(model=model_id)
+     return model
+
+
+ def combine_masks(masks: List[dict], labels: List[str], is_label=True) -> Image.Image:
+     """
+     Combine the masks for the specified labels into a single mask.
+
+     Parameters:
+     - masks (List[dict]): A list of dictionaries, each containing the mask under a 'mask' key and its label under a 'label' key.
+     - labels (List[str]): A list of labels to include (is_label=True) or exclude (is_label=False).
+
+     Returns:
+     - Image.Image: The combined mask as a PIL Image object, or None if no masks match.
+     """
+     labels_set = set(labels)  # Convert labels list to a set for O(1) lookups
+
+     # Filter and convert mask images based on the specified labels
+     mask_images = [
+         mask["mask"].convert("L")
+         for mask in masks
+         if (mask["label"] in labels_set) == is_label
+     ]
+
+     # Ensure there is at least one mask to combine
+     if not mask_images:
+         return None  # Or raise an appropriate error, e.g., ValueError("No masks found for the specified labels.")
+
+     # Initialize the combined mask with the first mask
+     combined_mask = mask_images[0]
+
+     # Union the remaining masks via a per-pixel maximum (ImageChops.lighter)
+     for mask in mask_images[1:]:
+         combined_mask = ImageChops.lighter(combined_mask, mask)
+
+     return combined_mask
+
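+ # Example (illustrative): combine_masks(results, ["face", "hair"], is_label=True)
+ # unions the face and hair masks into one image; with is_label=False it instead
+ # unions every mask whose label is NOT in the list.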
+
+ body_labels = ["hair", "face", "arm", "hand", "leg", "foot", "outfit"]
+
+
+ class BodyMask:
+
+     def __init__(
+         self,
+         image_path,
+         model_id,
+         labels=body_labels,
+         overlay="mask",
+         widen_box=0,
+         elongate_box=0,
+         resize_to=640,
+         dilate_factor=0,
+         is_label=False,
+         resize_to_nearest_eight=False,
+         verbose=True,
+         remove_overlap=True,
+     ):
+         self.image_path = image_path
+         self.image = self.get_image(
+             resize_to=resize_to, resize_to_nearest_eight=resize_to_nearest_eight
+         )
+         self.labels = labels
+         self.is_label = is_label
+         self.model_id = model_id
+         self.model = get_model(self.model_id)
+         self.model_labels = self.model.names
+         self.verbose = verbose
+         self.results = self.get_results()
+         self.dilate_factor = dilate_factor
+         self.body_mask = self.get_body_mask()
+         self.box = get_bounding_box(self.body_mask)
+         self.body_box = self.get_body_box(
+             remove_overlap=remove_overlap, widen=widen_box, elongate=elongate_box
+         )
+         if overlay == "box":
+             self.overlay = overlay_mask(
+                 self.image, self.body_box, opacity=0.9, color="red"
+             )
+         else:
+             self.overlay = overlay_mask(
+                 self.image, self.body_mask, opacity=0.9, color="red"
+             )
+
+     def get_image(self, resize_to, resize_to_nearest_eight):
+         image = load_image(self.image_path)
+         if resize_to:
+             image = resize_preserve_aspect_ratio(image, resize_to)
+         if resize_to_nearest_eight:
+             image = resize_image_to_nearest_eight(image)
+         return image
+
+     def get_body_mask(self):
+         body_mask = combine_masks(self.results, self.labels, self.is_label)
+         return dilate_mask(body_mask, self.dilate_factor)
+
+     def get_results(self):
+         imgsz = max(self.image.size)
+         results = self.model(
+             self.image, retina_masks=True, imgsz=imgsz, verbose=self.verbose
+         )[0]
+         self.masks, self.boxes, self.scores, self.phrases = unload(
+             results, self.model_labels
+         )
+         results = format_results(
+             self.masks,
+             self.boxes,
+             self.scores,
+             self.phrases,
+             self.model_labels,
+             person_masks_only=False,
+         )
+
+         # Keep only the highest-scoring detection for these labels
+         results = filter_highest_score(results, ["hair", "face", "phone"])
+         return results
+
+     def display_results(self):
+         cols = min(len(self.masks), 4)
+         display_image_with_masks(self.image, self.results, cols=cols)
+
+     def get_mask(self, mask_label):
+         assert mask_label in self.phrases, "Mask label not found in results"
+         return [f for f in self.results if f.get("label") == mask_label]
+
+     def combine_masks(self, mask_labels: List, no_labels=None, is_label=True):
+         """
+         Combine the masks for the labels in the list or, with is_label=False, all masks not in the list.
+         """
+         if not is_label:
+             mask_labels = [
+                 phrase for phrase in self.phrases if phrase not in mask_labels
+             ]
+         masks = [
+             row.get("mask") for row in self.results if row.get("label") in mask_labels
+         ]
+         if len(masks) == 0:
+             return None
+         combined_mask = masks[0]
+         for mask in masks[1:]:
+             combined_mask = ImageChops.lighter(combined_mask, mask)
+         return combined_mask
+
+     def get_body_box(self, remove_overlap=True, widen=0, elongate=0):
+         body_box = get_bounding_box_mask(self.body_mask, widen=widen, elongate=elongate)
+         if remove_overlap:
+             body_box = self.remove_overlap(body_box)
+         return body_box
+
+     def remove_overlap(self, body_box):
+         """
+         Remove box regions that overlap with the labeled body-part masks.
+         """
+         # convert box mask to numpy array
+         box_array = np.array(body_box)
+
+         # combine the masks for those labels
+         mask = self.combine_masks(mask_labels=self.labels, is_label=True)
+
+         # convert mask to numpy array
+         mask_array = np.array(mask)
+
+         # where the mask array is white set the box array to black
+         box_array[mask_array == 255] = 0
+
+         # convert the box array back to an image
+         mask_image = Image.fromarray(box_array)
+         return mask_image
+
+
+ if __name__ == "__main__":
+     url = "https://sjc1.vultrobjects.com/photo-storage/images/525d1f68-314c-455b-a8b6-f5dc3fa044e4.jpeg"
+     image_name = url.split("/")[-1]
+     labels = ["face", "hair", "phone", "hand"]
+     image = load_image(url)
+
+     # Get the original size of the image
+     original_size = image.size
+
+     # Create body mask (model_id is required; the checkpoint sits at the repo
+     # root, so adjust the path if running from elsewhere)
+     body_mask = BodyMask(
+         image,
+         model_id="yolo-human-parse-epoch-125.pt",
+         overlay="box",
+         labels=labels,
+         widen_box=50,
+         elongate_box=10,
+         dilate_factor=0,
+         resize_to=640,
+         is_label=False,
+         remove_overlap=True,
+         verbose=False,
+     )
+
+     # Resize the image back to the original size
+     image = body_mask.image.resize(original_size)
+     body_mask.body_box.save(image_name)
yolo/utils.py ADDED
@@ -0,0 +1,291 @@
+ import matplotlib.patches as patches
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from PIL import Image, ImageDraw
+ from skimage.transform import resize  # used by display_image_with_masks
+
+
+ def unload_mask(mask):
+     mask = mask.cpu().numpy().squeeze()
+     mask = mask.astype(np.uint8) * 255
+     return Image.fromarray(mask)
+
+
+ def unload_box(box):
+     return box.cpu().numpy().tolist()
+
+
+ def masks_overlap(mask1, mask2):
+     return np.any(np.logical_and(mask1, mask2))
+
+
+ def remove_non_person_masks(person_mask, formatted_results):
+     return [
+         f
+         for f in formatted_results
+         if f.get("label") == "person" or masks_overlap(person_mask, f.get("mask"))
+     ]
+
+
+ def format_masks(masks):
+     return [unload_mask(mask) for mask in masks]
+
+
+ def format_boxes(boxes):
+     return [unload_box(box) for box in boxes]
+
+
+ def format_scores(scores):
+     return scores.cpu().numpy().tolist()
+
+
+ def unload(result, labels_dict):
+     masks = format_masks(result.masks.data)
+     boxes = format_boxes(result.boxes.xyxy)
+     scores = format_scores(result.boxes.conf)
+     labels = result.boxes.cls
+     labels = [int(label.item()) for label in labels]
+     phrases = [labels_dict[label] for label in labels]
+     return masks, boxes, scores, phrases
+
+
+ def format_results(masks, boxes, scores, labels, labels_dict, person_masks_only=True):
+     if isinstance(list(labels_dict.keys())[0], int):
+         labels_dict = {v: k for k, v in labels_dict.items()}
+
+     # check that the person mask is present
+     if person_masks_only:
+         assert "person" in labels, "Person mask not present in results"
+     results_dict = []
+     for row in zip(labels, scores, boxes, masks):
+         label, score, box, mask = row
+         label_id = labels_dict[label]
+         results_row = dict(
+             label=label, score=score, mask=mask, box=box, label_id=label_id
+         )
+         results_dict.append(results_row)
+     results_dict = sorted(results_dict, key=lambda x: x["label"])
+     if person_masks_only:
+         # Get the person mask
+         person_mask = [f for f in results_dict if f.get("label") == "person"][0]["mask"]
+         assert person_mask is not None, "Person mask not found in results"
+
+         # Remove any results that do not overlap with the person
+         results_dict = remove_non_person_masks(person_mask, results_dict)
+     return results_dict
+
+
+ def filter_highest_score(results, labels):
+     """
+     Filter results to remove entries with lower scores for specified labels.
+
+     Args:
+         results (list): List of dictionaries containing 'label', 'score', and other keys.
+         labels (list): List of labels to filter.
+
+     Returns:
+         list: Filtered results with only the highest score for each specified label.
+     """
+     # Dictionary to keep track of the highest score entry for each label
+     label_highest = {}
+
+     # First pass: identify the highest score for each label
+     for result in results:
+         label = result["label"]
+         if label in labels:
+             if (
+                 label not in label_highest
+                 or result["score"] > label_highest[label]["score"]
+             ):
+                 label_highest[label] = result
+
+     # Second pass: construct the filtered list while preserving the order
+     filtered_results = []
+     seen_labels = set()
+
+     for result in results:
+         label = result["label"]
+         if label in labels:
+             if label in seen_labels:
+                 continue
+             if result == label_highest[label]:
+                 filtered_results.append(result)
+                 seen_labels.add(label)
+         else:
+             filtered_results.append(result)
+
+     return filtered_results
+
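+ # Example (illustrative): if results holds two "hair" entries with scores 0.91
+ # and 0.40, filter_highest_score(results, ["hair"]) keeps only the 0.91 entry;
+ # labels not in the list pass through untouched, in their original order.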
+
+ def display_image_with_masks(image, results, cols=4, return_images=False):
+     # Convert PIL Image to numpy array
+     image_np = np.array(image)
+
+     # Check image dimensions
+     if image_np.ndim != 3 or image_np.shape[2] != 3:
+         raise ValueError("Image must be a 3-dimensional array with 3 color channels")
+
+     # Number of masks
+     n = len(results)
+     rows = (n + cols - 1) // cols  # Calculate required number of rows
+
+     # Setting up the plot
+     fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows))
+     axs = np.array(axs).reshape(-1)  # Flatten axs array for easy indexing
+     for i, result in enumerate(results):
+         mask = result["mask"]
+         label = result["label"]
+         score = float(result["score"])
+
+         # Convert PIL mask to numpy array and resize if necessary
+         mask_np = np.array(mask)
+         if mask_np.shape != image_np.shape[:2]:
+             mask_np = resize(
+                 mask_np, image_np.shape[:2], mode="constant", anti_aliasing=False
+             )
+             mask_np = (mask_np > 0.5).astype(
+                 np.uint8
+             )  # Threshold back to binary after resize
+
+         # Create an overlay where mask is True
+         overlay = np.zeros_like(image_np)
+         overlay[mask_np > 0] = [0, 0, 255]  # Applying blue color on the mask area
+
+         # Combine the image and the overlay
+         combined = image_np.copy()
+         indices = np.where(mask_np > 0)
+         combined[indices] = combined[indices] * 0.5 + overlay[indices] * 0.5
+
+         # Show the combined image
+         ax = axs[i]
+         ax.imshow(combined)
+         ax.axis("off")
+         ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)
+         rect = patches.Rectangle(
+             (0, 0),
+             image_np.shape[1],
+             image_np.shape[0],
+             linewidth=1,
+             edgecolor="r",
+             facecolor="none",
+         )
+         ax.add_patch(rect)
+
+     # Hide unused subplots if the total number of masks is not a multiple of cols
+     for idx in range(i + 1, rows * cols):
+         axs[idx].axis("off")
+     plt.tight_layout()
+     plt.show()
+
+
+ def get_bounding_box(mask):
+     """
+     Given a segmentation mask, return the bounding box for the mask object.
+     """
+     # Find indices where the mask is non-zero; argwhere returns (row, col) pairs
+     coords = np.argwhere(mask)
+     # Get the minimum and maximum coordinates (rows are y, columns are x)
+     y_min, x_min = np.min(coords, axis=0)
+     y_max, x_max = np.max(coords, axis=0)
+     # Return the bounding box coordinates
+     return (x_min, y_min, x_max, y_max)
+
+
+ def get_bounding_box_mask(segmentation_mask, widen=0, elongate=0):
+     # Convert the PIL segmentation mask to a NumPy array
+     mask_array = np.array(segmentation_mask)
+
+     # Find the coordinates of the non-zero pixels
+     non_zero_y, non_zero_x = np.nonzero(mask_array)
+
+     # Calculate the bounding box coordinates
+     min_x, max_x = np.min(non_zero_x), np.max(non_zero_x)
+     min_y, max_y = np.min(non_zero_y), np.max(non_zero_y)
+
+     if widen > 0:
+         min_x = max(0, min_x - widen)
+         max_x = min(mask_array.shape[1], max_x + widen)
+
+     if elongate > 0:
+         min_y = max(0, min_y - elongate)
+         max_y = min(mask_array.shape[0], max_y + elongate)
+
+     # Create a new blank image for the bounding box mask
+     bounding_box_mask = Image.new("1", segmentation_mask.size)
+
+     # Draw the filled bounding box on the blank image
+     draw = ImageDraw.Draw(bounding_box_mask)
+     draw.rectangle([(min_x, min_y), (max_x, max_y)], fill=1)
+
+     return bounding_box_mask
+
+
+ colors = {
+     "blue": (136, 207, 249),
+     "red": (255, 0, 0),
+     "green": (0, 255, 0),
+     "yellow": (255, 255, 0),
+     "purple": (128, 0, 128),
+     "cyan": (0, 255, 255),
+     "magenta": (255, 0, 255),
+     "orange": (255, 165, 0),
+     "lime": (50, 205, 50),
+     "pink": (255, 192, 203),
+     "brown": (139, 69, 19),
+     "gray": (128, 128, 128),
+     "black": (0, 0, 0),
+     "white": (255, 255, 255),
+     "gold": (255, 215, 0),
+     "silver": (192, 192, 192),
+     "beige": (245, 245, 220),
+     "navy": (0, 0, 128),
+     "maroon": (128, 0, 0),
+     "olive": (128, 128, 0),
+ }
+
+
+ def overlay_mask(image, mask, opacity=0.5, color="blue"):
+     """
+     Takes in a PIL image and a PIL boolean image mask. Overlays the mask on the
+     image, tinting the masked area with the given color at the given opacity
+     (the default "blue" is hex #88CFF9).
+     """
+     # Convert the boolean mask to an image with alpha channel
+     alpha = mask.convert("L").point(lambda x: 255 if x == 255 else 0, mode="1")
+
+     # Choose the color
+     r, g, b = colors[color]
+
+     color_mask = Image.new("RGBA", mask.size, (r, g, b, int(opacity * 255)))
+     mask_rgba = Image.composite(
+         color_mask, Image.new("RGBA", mask.size, (0, 0, 0, 0)), alpha
+     )
+
+     # Create a new RGBA image to overlay the mask on
+     overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
+
+     # Paste the mask onto the overlay
+     overlay.paste(mask_rgba, (0, 0))
+
+     # Create a new image to return by blending the original image and the overlay
+     result = Image.alpha_composite(image.convert("RGBA"), overlay)
+
+     # Convert the result back to the original mode and return it
+     return result.convert(image.mode)
+
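+ # Example (illustrative): overlay_mask(image, mask, opacity=0.9, color="red")
+ # returns the image with the masked region tinted red, which is how BodyMask
+ # builds its overlay attribute.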
+
+ def resize_preserve_aspect_ratio(image, max_side=512):
+     width, height = image.size
+     scale = min(max_side / width, max_side / height)
+     new_width = int(width * scale)
+     new_height = int(height * scale)
+     return image.resize((new_width, new_height))
+
+
+ def round_to_nearest_eight(value):
+     return int(value // 8 * 8)
+
+
+ def resize_image_to_nearest_eight(image):
+     width, height = image.size
+     width, height = round_to_nearest_eight(width), round_to_nearest_eight(height)
+     image = image.resize((width, height))
+     return image