# YOLO-Human-Parse / hp/visualizer.py
# MnLgt's picture
# updated yolo model
# 95eae85
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io
def visualizer(
    image,
    results,
    box_label="box",
    mask_label="mask",
    prompt_label="prompt",
    score_label="score",
    cols=4,
    return_image=False,
    **kwargs,
):
    """Plot every result on its own copy of the image, arranged in a grid.

    Each result gets one panel showing the image with the result's mask
    blended in blue (if it has one), a red bounding box (if it has one) and
    a title built from its label and score.

    Args:
        image: Anything ``np.array`` turns into a 3-dimensional array
            (e.g. a PIL image with colour channels).
        results: Iterable of mappings. Each must contain ``prompt_label`` and
            ``score_label``; ``mask_label`` and ``box_label`` are optional.
        box_label: Key holding the box as ``(x1, y1, x2, y2)``.
        mask_label: Key holding a 2-dimensional mask; entries ``> 0`` are
            highlighted.
        prompt_label: Key holding the text shown in the panel title.
        score_label: Key holding a value convertible to ``float``.
        cols: Number of panels per row; must be at least 1.
        return_image: When true, render to PNG and return it as a PIL image
            instead of calling ``plt.show()``.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        A PIL ``Image`` when ``return_image`` is true, otherwise ``None``.

    Raises:
        ValueError: If the image is not 3-dimensional, a mask is not
            2-dimensional, or ``cols`` is smaller than 1.
    """
    image_np = np.array(image)
    if image_np.ndim != 3:
        raise ValueError("Image must be a 3-dimensional array")
    if cols < 1:
        raise ValueError("cols must be a positive integer")

    # Build every panel before touching matplotlib, so malformed results fail
    # fast and never leave a half-drawn figure registered with pyplot.
    panels = [
        _prepare_panel(
            image_np, result, box_label, mask_label, prompt_label, score_label
        )
        for result in results
    ]

    # Keep at least one row so an empty result list yields an empty grid
    # instead of an error from plt.subplots.
    rows = max(1, (len(panels) + cols - 1) // cols)

    # squeeze=False always returns a 2-D array of axes, whatever rows/cols are.
    fig, axs = plt.subplots(
        rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False
    )
    axes = axs.ravel()

    for ax, (combined, title, bbox) in zip(axes, panels):
        ax.imshow(combined)
        ax.axis("off")
        ax.set_title(title, fontsize=12)
        if bbox is not None:
            x1, y1, x2, y2 = bbox
            ax.add_patch(
                patches.Rectangle(
                    (x1, y1),
                    x2 - x1,
                    y2 - y1,
                    linewidth=2,
                    edgecolor="r",
                    facecolor="none",
                )
            )

    # Hide the panels left over when the result count is not a multiple of cols.
    for ax in axes[len(panels):]:
        ax.axis("off")

    # Work on this figure explicitly rather than on pyplot's "current" figure.
    fig.tight_layout()

    if return_image:
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png")
        finally:
            # Release the figure even if rendering fails.
            plt.close(fig)
        buf.seek(0)
        return Image.open(buf)

    plt.show()


def _prepare_panel(
    image_np, result, box_label, mask_label, prompt_label, score_label
):
    """Return ``(pixels, title, bbox_or_None)`` for a single result."""
    label = result[prompt_label]
    score = float(result[score_label])

    if mask_label in result:
        combined = _blend_mask(image_np, result[mask_label])
    else:
        combined = image_np.copy()

    bbox = None
    if box_label in result:
        # Unpack here so a malformed box is reported before any drawing starts.
        x1, y1, x2, y2 = result[box_label]
        bbox = (x1, y1, x2, y2)

    return combined, f"Label: {label}, Score: {score:.2f}", bbox


def _blend_mask(image_np, mask):
    """Return a copy of ``image_np`` with the mask area blended 50/50 with blue."""
    mask_np = np.array(mask)
    if mask_np.ndim != 2:
        raise ValueError("Mask must be a 2-dimensional array")

    combined = image_np.copy()
    selected = mask_np > 0
    blue = np.array([0, 0, 255])
    # Only the first three channels are blended, so an alpha channel (if any)
    # is left untouched; for 3-channel images this covers every channel.
    combined[selected, :3] = combined[selected, :3] * 0.5 + blue * 0.5
    return combined