File size: 2,475 Bytes
24e5d63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import json
import os
import time

import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw
from segment_anything import SamPredictor, sam_model_registry
# SAM backbone to use; each backbone needs its matching checkpoint file.
model_type = "vit_b"
# Replace YOUR_PATH with the directory holding the downloaded checkpoints:
#   vit_h: https://huggingface.co/spaces/abhishek/StableSAM/blob/main/sam_vit_h_4b8939.pth
#   vit_b: https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/sams/sam_vit_b_01ec64.pth
SAM_CHECKPOINTS = {
    "vit_h": "YOUR_PATH/sam_vit_h_4b8939.pth",
    "vit_b": "YOUR_PATH/sam_vit_b_01ec64.pth",
}
try:
    sam_checkpoint = SAM_CHECKPOINTS[model_type]
except KeyError:
    raise ValueError(f"Invalid model type: {model_type}") from None

# Run on the GPU when one is available. The timing code synchronizes CUDA,
# but previously the model was never moved off the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
# Input frames: one sub-directory per sequence under image_folder.
image_folder = 'data/VOI'
# Keep only directories: a stray file here (e.g. .DS_Store) would make the
# os.listdir(seq_path) call in the processing loop fail.
seqs = sorted(
    entry
    for entry in os.listdir(image_folder)
    if os.path.isdir(os.path.join(image_folder, entry))
)
# Ground-truth trajectories, read as <seq>.json for each sequence.
trajectory_path = 'data/VOI-GT'
# Output directory for first frames overlaid with their SAM mask.
img_with_initial_point = 'data/VOI-initial-point'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(img_with_initial_point, exist_ok=True)
# Measure elapsed time: segment the first frame of every sequence with SAM,
# prompted by the first ground-truth point, and save a red-tinted overlay.


def _cuda_sync():
    """Wait for queued CUDA work so wall-clock timings are accurate.

    torch.cuda.synchronize() raises on hosts without CUDA, so it is skipped there.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _load_initial_point(json_path):
    """Return the first annotated point from a ground-truth trajectory file.

    NOTE(review): assumes data[0][0][0] is the [x, y] pixel coordinate of the
    first frame's point -- confirm against the GT file format.
    """
    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)
    return data[0][0][0]


def _overlay_mask(image, mask, color=(255, 0, 0, 100)):
    """Return an RGB copy of *image* with the pixels of *mask* tinted by *color*.

    image: PIL image. mask: 2-D boolean array of the same height/width.
    color: RGBA tint; its alpha sets how strongly the mask shows through.
    """
    mask_image = Image.fromarray(mask.astype(np.uint8) * 255)
    # Fully transparent layer; only masked pixels receive the tint.
    tint = Image.new("RGBA", image.size, (0, 0, 0, 0))
    tint.paste(color, (0, 0), mask_image)
    return Image.alpha_composite(image.convert("RGBA"), tint).convert("RGB")


# Derive the suffix from the backbone so outputs from different models do not
# collide ("vit_b" -> "_segmented_sam_b.jpg", as before).
output_suffix = f"_segmented_sam_{model_type.split('_')[-1]}.jpg"

_cuda_sync()
start = time.perf_counter()
for seq_idx, seq in enumerate(seqs, start=1):
    print(f"Processing {seq}, {seq_idx}/{len(seqs)}")
    seq_path = os.path.join(image_folder, seq)
    images = sorted(os.listdir(seq_path))
    if not images:
        print(f"Skipping {seq}: no frames found in {seq_path}")
        continue
    image_path = os.path.join(seq_path, images[0])
    # Decode once; convert() loads the pixels so the file can be closed.
    # The same RGB image feeds SAM and the overlay, so the mask stays aligned.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    init_point = _load_initial_point(os.path.join(trajectory_path, seq + '.json'))
    print(init_point)
    # Step 1: prompt SAM with the initial point as a single foreground click.
    # set_image expects RGB by default; cv2.imread would have supplied BGR.
    predictor.set_image(np.asarray(image))
    masks, _, _ = predictor.predict(
        point_coords=np.array([init_point]),
        point_labels=np.array([1]),
        multimask_output=False,
    )
    combined_image_rgb = _overlay_mask(image, masks[0])
    save_path = os.path.join(img_with_initial_point, seq + output_suffix)
    combined_image_rgb.save(save_path)
_cuda_sync()
end = time.perf_counter()
print(f"Time: {end-start}")
# max(..., 1) guards against an empty dataset.
print(f"Average time: {(end-start)/max(len(seqs), 1)}")
|