from typing import Dict, Tuple

import torch

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# UI mode labels. VIDEO_SEGMENTATION_MODE is defined but not listed in
# MODE_NAMES, so only the two image modes are exposed.
BOX_PROMPT_MODE = "box prompt"
MASK_GENERATION_MODE = "mask generation"
VIDEO_SEGMENTATION_MODE = "video segmentation"
MODE_NAMES = [BOX_PROMPT_MODE, MASK_GENERATION_MODE]

# Model size -> (model config, checkpoint path).
CHECKPOINT_NAMES = ["tiny", "small", "base_plus", "large"]
CHECKPOINTS = {
    "tiny": ("sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"),
    "small": ("sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"),
    "base_plus": ("sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt"),
    "large": ("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"),
}

def load_models(
    device: torch.device,
) -> Tuple[Dict[str, SAM2ImagePredictor], Dict[str, SAM2AutomaticMaskGenerator]]:
    """Build each SAM2 checkpoint once and wrap it for both demo modes.

    Returns two dicts keyed by checkpoint name ("tiny", "small", "base_plus",
    "large"): SAM2ImagePredictor instances for box prompting and
    SAM2AutomaticMaskGenerator instances for automatic mask generation.
    """
    image_predictors = {}
    mask_generators = {}
    for key, (config, checkpoint) in CHECKPOINTS.items():
        model = build_sam2(config, checkpoint, device=device)
        # Both wrappers share the same underlying model, so each checkpoint
        # is loaded onto the device only once.
        image_predictors[key] = SAM2ImagePredictor(sam_model=model)
        mask_generators[key] = SAM2AutomaticMaskGenerator(
            model=model,
            points_per_side=32,
            points_per_batch=64,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=0.7,
            crop_n_layers=1,
            box_nms_thresh=0.7,
        )
    return image_predictors, mask_generators
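
# Minimal usage sketch (illustrative, not part of the demo app): assumes the
# checkpoint files listed above exist locally and that "example.jpg" is a
# hypothetical test image. set_image/predict and generate are the standard
# SAM2ImagePredictor and SAM2AutomaticMaskGenerator entry points.
if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image_predictors, mask_generators = load_models(device)

    image = np.array(Image.open("example.jpg").convert("RGB"))

    # Box-prompt mode: predict masks inside an (x1, y1, x2, y2) box.
    predictor = image_predictors["tiny"]
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(box=np.array([100, 100, 400, 400]))
    print(f"box prompt: {len(masks)} mask(s), best IoU score {scores.max():.3f}")

    # Mask-generation mode: segment everything in the image.
    annotations = mask_generators["tiny"].generate(image)
    print(f"automatic mode: {len(annotations)} segments")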