Update README.md
Browse files
README.md
CHANGED
|
@@ -108,123 +108,134 @@ plt.show()
|
|
| 108 |
|
| 109 |
## Complete Example with Visualization
|
| 110 |
|
| 111 |
-
Here's a complete example showing how to use SAM-HQ with the image embedding workflow and how to visualize the results:
|
| 112 |
-
|
| 113 |
```python
|
| 114 |
-
import torch
|
| 115 |
import numpy as np
|
| 116 |
import matplotlib.pyplot as plt
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
from transformers import SamHQModel, SamHQProcessor
|
| 120 |
|
| 121 |
-
# 1. Load model and processor
|
| 122 |
# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
model = model.to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")

# Download the demo image and display it without axes
img_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
plt.figure(figsize=(10, 10))
plt.imshow(raw_image)
plt.axis('off')
plt.show()

# 3. Compute image embeddings
inputs = processor(raw_image, return_tensors="pt")
inputs = inputs.to(device)
image_embeddings, intermediate_embeddings = model.get_image_embeddings(inputs["pixel_values"])
|
| 137 |
|
| 138 |
-
|
| 139 |
-
input_boxes = [[[306, 132, 925, 893]]]
|
| 140 |
|
| 141 |
-
# Helper function to display bounding box
|
| 142 |
-
def show_box(box, ax):
    """Draw ``box`` on ``ax`` as an unfilled green rectangle.

    ``box`` is indexed as ``(x_min, y_min, x_max, y_max)``: the first two
    entries give the anchor corner and the last two the opposite corner.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)
|
| 146 |
-
|
| 147 |
-
# Show the image together with the box prompt
plt.figure(figsize=(10, 10))
plt.imshow(raw_image)
for box in input_boxes[0]:
    show_box(box, plt.gca())
plt.axis('on')
plt.title("Input Image with Bounding Box")
plt.show()

# 5. Run inference with the bounding box
# Encode the box through the processor rather than passing a raw tensor of
# pixel coordinates: the processor is what prepares prompts for the model.
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)

# Swap the pixel values for the embeddings computed earlier so the image
# is not pushed through the vision encoder a second time.
inputs.pop("pixel_values", None)
inputs.update({"image_embeddings": image_embeddings})
inputs.update({"intermediate_embeddings": intermediate_embeddings})

# Run inference
with torch.no_grad():
    outputs = model(**inputs)

# 6. Post-process the masks
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
|
| 173 |
|
| 174 |
-
|
| 175 |
-
# Helper function to show masks
|
| 176 |
-
def show_mask(mask, ax, random_color=False):
    """Overlay a mask on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: Array-like whose last two dimensions are (height, width) and
            whose total size is height * width. NumPy arrays and nested
            lists are accepted, as is anything else ``np.asarray`` can
            convert.
        ax: Object exposing ``imshow`` (e.g. a Matplotlib ``Axes``).
        random_color: When true, draw the mask in a random RGB colour;
            otherwise use a fixed blue. Alpha is 0.6 in both cases.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    # Normalise to an ndarray up front so the multiplication below never
    # depends on how another array type mixes with NumPy operands.
    mask = np.asarray(mask)
    h, w = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4): masked-out pixels become fully
    # transparent, masked-in pixels take the RGBA colour.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
|
| 184 |
|
| 185 |
-
|
| 186 |
-
# Drop singleton dimensions so masks and scores can be iterated in lockstep
masks_to_show = masks[0].squeeze() if len(masks[0].shape) == 4 else masks[0]
scores_to_show = scores.squeeze() if scores.shape[0] == 1 else scores

# One subplot per predicted mask
nb_predictions = scores_to_show.shape[-1]
fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))

# plt.subplots hands back a bare Axes for a single column; wrap it so the
# indexing below works either way
if nb_predictions == 1:
    axes = [axes]

for i, (mask, score) in enumerate(zip(masks_to_show, scores_to_show)):
    mask = mask.cpu().detach()
    axes[i].imshow(np.array(raw_image))
    show_mask(mask, axes[i])
    axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
    axes[i].axis("off")
plt.tight_layout()
plt.show()

# Second figure: every mask drawn on the same image, each in a random colour
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(np.array(raw_image))
for i, (mask, score) in enumerate(zip(masks_to_show, scores_to_show)):
    if len(mask.shape) > 2:
        mask = mask.squeeze()
    show_mask(mask, ax, random_color=True)
ax.set_title("All Masks Overlaid")
ax.axis("off")
plt.tight_layout()
plt.show()
|
| 224 |
```
|
| 225 |
|
| 226 |
-
This example demonstrates the complete workflow of using SAM-HQ with the "sushmanth/sam_hq_vit_b" model. It computes image embeddings once and then uses them for inference with a bounding box prompt. The resulting masks are visualized both individually with their confidence scores and overlaid on a single image with different colors.
|
| 227 |
-
|
| 228 |
# Citation
|
| 229 |
|
| 230 |
```
|
|
|
|
| 108 |
|
| 109 |
## Complete Example with Visualization
|
| 110 |
|
|
|
|
|
|
|
| 111 |
```python
|
|
|
|
| 112 |
import numpy as np
|
| 113 |
import matplotlib.pyplot as plt
|
| 114 |
+
def show_mask(mask, ax, random_color=False):
    """Overlay a mask on ``ax`` as a semi-transparent RGBA image.

    Args:
        mask: Array-like whose last two dimensions are (height, width) and
            whose total size is height * width. NumPy arrays and nested
            lists are accepted, as is anything else ``np.asarray`` can
            convert.
        ax: Object exposing ``imshow`` (e.g. a Matplotlib ``Axes``).
        random_color: When true, draw the mask in a random RGB colour;
            otherwise use a fixed blue. Alpha is 0.6 in both cases.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    # Normalise to an ndarray up front so the multiplication below never
    # depends on how another array type mixes with NumPy operands.
    mask = np.asarray(mask)
    h, w = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4): masked-out pixels become fully
    # transparent, masked-in pixels take the RGBA colour.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
|
| 122 |
+
def show_box(box, ax):
    """Draw ``box`` on ``ax`` as an unfilled green rectangle.

    ``box`` is indexed as ``(x_min, y_min, x_max, y_max)``: the first two
    entries give the anchor corner and the last two the opposite corner.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)
|
| 126 |
+
def show_boxes_on_image(raw_image, boxes):
    """Open a new 10x10 figure showing ``raw_image`` with every box outlined.

    Each element of ``boxes`` is forwarded unchanged to ``show_box``.
    """
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    # imshow has already created the axes, so fetch them once for all boxes
    current_axes = plt.gca()
    for rect in boxes:
        show_box(rect, current_axes)
    plt.axis('on')
    plt.show()
|
| 133 |
+
def show_points_on_image(raw_image, input_points, input_labels=None):
    """Open a new 10x10 figure showing ``raw_image`` with point prompts.

    ``input_points`` is converted with ``np.array`` and indexed as a 2-D
    array. When ``input_labels`` is ``None`` every point is given label 1;
    otherwise the labels are converted with ``np.array`` and passed to
    ``show_points`` as-is.
    """
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    labels = (
        np.ones_like(input_points[:, 0])
        if input_labels is None
        else np.array(input_labels)
    )
    show_points(input_points, labels, plt.gca())
    plt.axis('on')
    plt.show()
|
| 144 |
+
def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
    """Display ``raw_image`` with both point prompts and box prompts drawn on it.

    Args:
        raw_image: Image handed to ``plt.imshow``.
        boxes: Iterable of boxes; each one is forwarded to ``show_box``.
        input_points: Point coordinates, converted with ``np.array`` and
            indexed as a 2-D array.
        input_labels: Optional per-point labels. When ``None`` every point
            is given label 1.
    """
    plt.figure(figsize=(10,10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
        # Default: one label of 1 per point
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    for box in boxes:
        show_box(box, plt.gca())
    plt.axis('on')
    plt.show()
|
| 170 |
+
def show_points(coords, labels, ax, marker_size=375):
    """Scatter point prompts on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: Point coordinates indexable as a 2-D array, with x in
            column 0 and y in column 1. Plain nested lists are accepted.
        labels: One label per point; only the values 1 and 0 are drawn.
        ax: Object exposing ``scatter`` (e.g. a Matplotlib ``Axes``).
        marker_size: Marker area forwarded to ``scatter`` as ``s``.
    """
    # Boolean-mask indexing below needs ndarrays; coercing here lets callers
    # pass plain Python lists as well.
    coords = np.asarray(coords)
    labels = np.asarray(labels)
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
|
| 175 |
+
def show_masks_on_image(raw_image, masks, scores):
    """Show each predicted mask over ``raw_image`` in its own subplot.

    Args:
        raw_image: Image convertible with ``np.array``.
        masks: Tensor whose last two dimensions are (height, width); all
            leading dimensions are flattened into one list of masks.
        scores: Tensor of per-mask scores, flattened the same way. Each
            score is printed in its subplot title.
    """
    # Flatten explicitly instead of calling squeeze(): squeeze() removes
    # every unit dimension, so a single mask or score would lose the axis
    # being iterated, and a multi-box batch would keep an extra one.
    masks = masks.reshape(-1, *masks.shape[-2:])
    scores = scores.reshape(-1)
    nb_predictions = min(masks.shape[0], scores.shape[0])
    # squeeze=False keeps `axes` a 2-D array even with a single column, so
    # indexing works when there is only one prediction.
    _, axes = plt.subplots(1, nb_predictions, figsize=(15, 15), squeeze=False)
    for i, (mask, score) in enumerate(zip(masks, scores)):
        mask = mask.cpu().detach()
        panel = axes[0, i]
        panel.imshow(np.array(raw_image))
        show_mask(mask, panel)
        panel.title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
        panel.axis("off")
    plt.show()
|
| 189 |
+
def show_masks_on_single_image(raw_image, masks, scores):
    """Overlay every predicted mask on one copy of ``raw_image``.

    Args:
        raw_image: Image convertible with ``np.array``.
        masks: Tensor whose last two dimensions are (height, width); all
            leading dimensions are flattened into one list of masks.
        scores: Tensor of per-mask scores. It is only paired with ``masks``
            during iteration; the values themselves are not displayed.
    """
    # Flatten explicitly instead of calling squeeze(): squeeze() removes
    # every unit dimension, so a single mask would lose the axis being
    # iterated and a multi-box batch would keep an extra one.
    masks = masks.reshape(-1, *masks.shape[-2:])
    scores = scores.reshape(-1)

    image_np = np.array(raw_image)
    _, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(image_np)

    # Draw every mask on the same axes
    for mask, _score in zip(masks, scores):
        show_mask(mask.cpu().detach().numpy(), ax)

    # Plain literal: the original f-string had no placeholders
    ax.set_title("Overlayed Masks with Scores")
    ax.axis("off")
    plt.show()
|
| 206 |
+
|
| 207 |
+
import torch
|
| 208 |
from transformers import SamHQModel, SamHQProcessor
|
| 209 |
|
|
|
|
| 210 |
# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b")
model = model.to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")

from PIL import Image
import requests

# Download the demo image and display it
img_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
plt.imshow(raw_image)

# Run the processor on the image and compute its embeddings
inputs = processor(raw_image, return_tensors="pt")
inputs = inputs.to(device)
image_embeddings, intermediate_embeddings = model.get_image_embeddings(inputs["pixel_values"])

# Box prompt; show_box reads each box as (x_min, y_min, x_max, y_max)
input_boxes = [[[306, 132, 925, 893]]]
show_boxes_on_image(raw_image, input_boxes[0])
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
# Encode the box prompt with the processor. Without this step the box is
# displayed but never reaches the model, so inference would run unprompted.
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to(device)

# Swap the pixel values for the embeddings computed earlier so the image
# is not pushed through the vision encoder a second time.
inputs.pop("pixel_values", None)
inputs.update({"image_embeddings": image_embeddings})
inputs.update({"intermediate_embeddings": intermediate_embeddings})

with torch.no_grad():
    outputs = model(**inputs)

# Post-process the predicted masks using the original and resized image sizes
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
scores = outputs.iou_scores

# All masks on one image, then one subplot per mask with its score
show_masks_on_single_image(raw_image, masks[0], scores)

show_masks_on_image(raw_image, masks[0], scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
```
|
| 238 |
|
|
|
|
|
|
|
| 239 |
# Citation
|
| 240 |
|
| 241 |
```
|