|
import json
import os

import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw
from segment_anything import SamPredictor, sam_model_registry
|
|
|
# Choose the SAM backbone. Supported values are the keys of SAM_CHECKPOINTS.
model_type = "vit_b"

# Checkpoint file per backbone. NOTE: "YOUR_PATH" is a placeholder; point it
# at the directory holding the downloaded weights.
SAM_CHECKPOINTS = {
    "vit_h": "YOUR_PATH/sam_vit_h_4b8939.pth",
    "vit_b": "YOUR_PATH/sam_vit_b_01ec64.pth",
}

try:
    sam_checkpoint = SAM_CHECKPOINTS[model_type]
except KeyError:
    raise ValueError(f"Invalid model type: {model_type}") from None

# Place the model on the GPU when one exists. The timing code further down
# calls torch.cuda.synchronize(), which only measures model work if the model
# actually runs on CUDA; fall back to CPU so the script still runs without one.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Prompt-based predictor wrapping the loaded model (set_image, then predict).
predictor = SamPredictor(sam)
# Input frames: one sub-directory per sequence, processed in sorted order.
# Stray files (e.g. .DS_Store) are ignored; listing them as sequences would
# make the later os.listdir(seq_path) fail.
image_folder = 'data/VOI'
seqs = sorted(
    entry
    for entry in os.listdir(image_folder)
    if os.path.isdir(os.path.join(image_folder, entry))
)

# Ground-truth trajectories: one "<seq>.json" file per sequence.
trajectory_path = 'data/VOI-GT'

# Output directory for the first-frame mask overlays. exist_ok avoids the
# check-then-create race of a separate os.path.exists test.
img_with_initial_point = 'data/VOI-initial-point'
os.makedirs(img_with_initial_point, exist_ok=True)
import time

import torch


def _cuda_sync():
    """Wait for queued CUDA work so wall-clock timings include it.

    No-op when CUDA is unavailable, so the script also runs on CPU-only hosts.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _overlay_mask(image, mask):
    """Return ``image`` as RGB with ``mask`` blended on top in translucent red.

    Args:
        image: PIL image the mask was predicted on.
        mask: 2-D array matching ``image``'s height/width; non-zero pixels are
            tinted red.

    Returns:
        A new RGB PIL image.
    """
    mask_image = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
    red_mask = Image.new('RGBA', image.size, (0, 0, 0, 0))
    # Fill red (alpha 100/255) only where the mask is set.
    red_mask.paste((255, 0, 0, 100), (0, 0), mask_image)
    return Image.alpha_composite(image.convert("RGBA"), red_mask).convert('RGB')


# Output-file tag derived from the backbone ("vit_b" -> "b", "vit_h" -> "h")
# so runs with different models do not overwrite each other's overlays.
model_tag = model_type.split("_")[-1]

_cuda_sync()
start = time.perf_counter()

for idx, seq in enumerate(seqs, start=1):
    print(f"Processing {seq}, {idx}/{len(seqs)}")

    seq_path = os.path.join(image_folder, seq)
    images = sorted(os.listdir(seq_path))
    if not images:
        print(f"Skipping {seq}: no frames in {seq_path}")
        continue

    # Only the first frame (in sorted filename order) is segmented.
    image_path = os.path.join(seq_path, images[0])
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    file_path = os.path.join(trajectory_path, seq + '.json')
    with open(file_path, encoding="utf-8") as f:
        data = json.load(f)
    # NOTE(review): assumes the GT JSON nests so that data[0][0][0] is the
    # initial (x, y) pixel coordinate of the tracked point; confirm schema.
    init_point = data[0][0][0]
    print(init_point)

    # SamPredictor.set_image expects RGB by default; feeding cv2.imread's BGR
    # output silently swapped channels. Using the PIL image also guarantees
    # the mask is computed on exactly the pixels we overlay below.
    predictor.set_image(np.asarray(image))
    masks, _, _ = predictor.predict(
        point_coords=np.array([init_point]),
        point_labels=np.array([1]),  # 1 = foreground click
        multimask_output=False,
    )

    combined_image_rgb = _overlay_mask(image, masks[0])
    save_path = os.path.join(
        img_with_initial_point, f"{seq}_segmented_sam_{model_tag}.jpg"
    )
    combined_image_rgb.save(save_path)

_cuda_sync()
end = time.perf_counter()

print(f"Time: {end-start}")
if seqs:
    print(f"Average time: {(end-start)/len(seqs)}")
|