import io

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


def visualizer(
    image,
    results,
    box_label="box",
    mask_label="mask",
    prompt_label="prompt",
    score_label="score",
    cols=4,
    return_image=False,
    **kwargs,
):
    """Plot one panel per result, overlaying its mask and bounding box.

    Every panel shows a copy of ``image``. If the result has a mask, the
    masked pixels are blended 50/50 with blue; if it has a box, a red
    rectangle outline is drawn. The panel title shows the result's prompt
    and score.

    Args:
        image: Anything ``np.array`` turns into a 3-dimensional array
            (for example an RGB or RGBA PIL image).
        results: Sized sequence of mappings. Each must contain
            ``prompt_label`` and ``score_label``; ``mask_label`` and
            ``box_label`` are optional.
        box_label: Key holding an ``(x1, y1, x2, y2)`` box, drawn in the
            image's data coordinates.
        mask_label: Key holding a 2-dimensional mask; entries ``> 0`` are
            highlighted.
        prompt_label: Key holding the text shown in the panel title.
        score_label: Key holding a value convertible to ``float``.
        cols: Number of panels per row (must be at least 1).
        return_image: If true, render to PNG and return it as a PIL image
            instead of calling ``plt.show()``.
        **kwargs: Accepted for call-site compatibility; currently unused.

    Returns:
        A ``PIL.Image.Image`` when ``return_image`` is true, else ``None``.

    Raises:
        ValueError: If the image is not 3-dimensional, a mask is not
            2-dimensional, or ``cols`` is smaller than 1.
    """
    image_np = np.array(image)
    if image_np.ndim != 3:
        raise ValueError("Image must be a 3-dimensional array")
    if cols < 1:
        raise ValueError("cols must be a positive integer")

    n = len(results)
    # Ceiling division; keep at least one row so an empty result list
    # still yields a valid (blank) figure instead of crashing.
    rows = max(1, (n + cols - 1) // cols)

    # squeeze=False always returns a (rows, cols) array of Axes, so no
    # special-casing is needed for single-row or single-panel grids.
    fig, axs = plt.subplots(
        rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False
    )

    try:
        for i, result in enumerate(results):
            label = result[prompt_label]
            score = float(result[score_label])
            ax = axs[i // cols, i % cols]

            # Work on a copy so each panel starts from the pristine image.
            combined = image_np.copy()

            if mask_label in result:
                mask_np = np.array(result[mask_label])
                if mask_np.ndim != 2:
                    raise ValueError("Mask must be a 2-dimensional array")

                # Blend the masked pixels 50/50 with blue. Only the first
                # three channels are touched so that an alpha channel, if
                # present, is left unchanged.
                selected = mask_np > 0
                blue = np.array([0, 0, 255])
                combined[selected, :3] = (
                    combined[selected, :3] * 0.5 + blue * 0.5
                )

            ax.imshow(combined)
            ax.axis("off")
            ax.set_title(f"Label: {label}, Score: {score:.2f}", fontsize=12)

            if box_label in result:
                x1, y1, x2, y2 = result[box_label]
                rect = patches.Rectangle(
                    (x1, y1),
                    x2 - x1,
                    y2 - y1,
                    linewidth=2,
                    edgecolor="r",
                    facecolor="none",
                )
                ax.add_patch(rect)

        # Blank out the grid cells that did not receive a result.
        for unused_ax in axs.flat[n:]:
            unused_ax.axis("off")

        # Operate on this figure explicitly rather than on pyplot's
        # "current figure", which may be a different one.
        fig.tight_layout()
    except Exception:
        # Do not leave a half-built figure registered with pyplot.
        plt.close(fig)
        raise

    if return_image:
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        buf.seek(0)
        plt.close(fig)
        return Image.open(buf)

    plt.show()
    return None