import torch.nn.functional as F
from mmengine.model import BaseModel

from mmdet.registry import MODELS
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig

# Registering with MODELS lets the segmentor be built from an mmengine config
# (assumed here; standard practice for mmdet models).
@MODELS.register_module()
class SAMSegmentor(BaseModel):
    # Cut-off for binarizing predicted mask probabilities.
    MASK_THRESHOLD = 0.5
    def __init__(
        self,
        backbone: ConfigType,
        neck: ConfigType,
        prompt_encoder: ConfigType,
        mask_decoder: ConfigType,
        data_preprocessor: OptConfigType = None,
        fpn_neck: OptConfigType = None,
        init_cfg: OptMultiConfig = None,
        use_clip_feat: bool = False,
        use_head_feat: bool = False,
        use_gt_prompt: bool = False,
        use_point: bool = False,
        enable_backbone: bool = False,
    ) -> None:
        super().__init__(data_preprocessor=data_preprocessor, init_cfg=init_cfg)
        self.backbone = MODELS.build(backbone)
        self.neck = MODELS.build(neck)
        self.pe = MODELS.build(prompt_encoder)
        self.mask_decoder = MODELS.build(mask_decoder)
        # The FPN neck is optional; when absent, extract_feat returns no FPN features.
        self.fpn_neck = MODELS.build(fpn_neck) if fpn_neck is not None else None
        self.use_clip_feat = use_clip_feat
        self.use_head_feat = use_head_feat
        self.use_gt_prompt = use_gt_prompt
        self.use_point = use_point
        # When enabled, backbone/FPN features (and the backbone module itself)
        # are forwarded to the mask decoder in extract_masks.
        self.enable_backbone = enable_backbone

    def extract_feat(self, inputs):
        # Run the image through the backbone and neck; the optional FPN neck
        # produces multi-scale features from the same backbone output.
        backbone_feat = self.backbone(inputs)
        neck_feat = self.neck(backbone_feat)
        fpn_feat = self.fpn_neck(backbone_feat) if self.fpn_neck is not None else None
        # Cache all three so extract_masks can be called repeatedly with
        # different prompts without recomputing image features.
        return dict(
            backbone_feat=backbone_feat,
            neck_feat=neck_feat,
            fpn_feat=fpn_feat,
        )

    def extract_masks(self, feat_cache, prompts):
        # Encode the prompts (points and/or boxes) into sparse and dense
        # embeddings at the fixed 1024x1024 SAM input resolution.
        sparse_embed, dense_embed = self.pe(
            prompts,
            image_size=(1024, 1024),
            with_points='point_coords' in prompts,
            with_bboxes='bboxes' in prompts,
        )
        kwargs = dict()
        if self.enable_backbone:
            # Give the decoder access to the raw backbone/FPN features as well.
            kwargs['backbone_feats'] = feat_cache['backbone_feat']
            kwargs['backbone'] = self.backbone
            kwargs['fpn_feats'] = feat_cache['fpn_feat']
        low_res_masks, iou_predictions, cls_pred = self.mask_decoder(
            image_embeddings=feat_cache['neck_feat'],
            image_pe=self.pe.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embed,
            dense_prompt_embeddings=dense_embed,
            multi_mask_output=False,
            **kwargs
        )
        # Upsample the low-resolution decoder masks (4x, e.g. 256 -> 1024)
        # back to the input resolution.
        masks = F.interpolate(
            low_res_masks,
            scale_factor=4.,
            mode='bilinear',
            align_corners=False,
        )
        # Return mask probabilities; callers can binarize with MASK_THRESHOLD.
        masks = masks.sigmoid()
        # Class probabilities with the trailing background column dropped.
        cls_pred = cls_pred.softmax(-1)[..., :-1]
        return masks.detach().cpu().numpy(), cls_pred.detach().cpu()

    def forward(self, inputs):
        # Pass-through stub: this segmentor is driven via extract_feat /
        # extract_masks rather than the usual forward pass.
        return inputs
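

# A minimal usage sketch (not part of the original file). It assumes config
# dicts for registered backbone / neck / prompt-encoder / mask-decoder modules
# exist as `backbone_cfg`, `neck_cfg`, `prompt_encoder_cfg` and
# `mask_decoder_cfg` -- all hypothetical names. It illustrates the intended
# two-step flow: cache image features once, then decode masks per prompt set.
#
#     import torch
#
#     segmentor = SAMSegmentor(
#         backbone=backbone_cfg,                 # hypothetical config dicts
#         neck=neck_cfg,
#         prompt_encoder=prompt_encoder_cfg,
#         mask_decoder=mask_decoder_cfg,
#     )
#     image = torch.randn(1, 3, 1024, 1024)      # SAM's fixed input size
#     feat_cache = segmentor.extract_feat(image)
#
#     # Box prompts in (x1, y1, x2, y2) pixel coordinates, keyed as the
#     # prompt encoder expects ('bboxes' / 'point_coords').
#     prompts = dict(bboxes=torch.tensor([[100., 100., 400., 400.]]))
#     masks, cls_pred = segmentor.extract_masks(feat_cache, prompts)
#
#     # `masks` holds per-prompt probabilities; binarize with the class cut-off.
#     binary_masks = masks > SAMSegmentor.MASK_THRESHOLD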