import copy
import os
from typing import Literal, Tuple, List, Optional

import torch
import torch.nn.functional as F
from torch import nn

from mmcv.cnn import ConvModule
from mmdet.structures.bbox import bbox2roi
from mmdet.structures.mask import mask2bbox
from mmengine import MMLogger
from mmengine.model import BaseModule
from mmdet.registry import MODELS

from ext.sam import MaskDecoder
from ext.sam.mask_decoder import MLP as SAMMLP
from ext.meta.sam_meta import meta_dict, checkpoint_dict
from utils.load_checkpoint import load_checkpoint_with_prefix


class OVSAMHead(BaseModule):
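    """SAM-style mask decoder head with an optional open-vocabulary label token.

    Besides the usual SAM mask prediction, the head can classify each predicted
    mask against pre-computed, L2-normalized class embeddings (one or more
    prototypes per class) using ROI-pooled image features.
    """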

    def __init__(
            self,
            model_name: Literal['vit_h', 'vit_l', 'vit_b'] = 'vit_h',
            with_label_token: bool = False,
            ov_classifier_name: Optional[str] = None,
            logit: Optional[float] = None,
            roi_extractor=None,
            fix: bool = True,
            init_cfg=None,
            cur_mask=1,
            roi_extractor_single=None,
            load_roi_conv=None,
            gen_box=False,
    ):
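        """
        Args (selected):
            model_name: SAM variant used to look up ``prompt_embed_dim`` in ``meta_dict``.
            with_label_token: add a classification ("label") token to the decoder queries.
            ov_classifier_name: name of the pre-computed class-embedding file under ``./models/``.
            logit: optional fixed logit scale; when ``None`` it defaults to 4.6052 (≈ ln 100).
            fix: freeze all parameters and keep the head in eval mode.
            gen_box: derive ROIs from the predicted masks instead of an external ``roi_list``.
        """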
        assert init_cfg is not None and \
               init_cfg['type'] in ['sam_pretrain', 'Pretrained'], f"{init_cfg['type']} is not supported."
        pretrained = init_cfg['checkpoint']
        super().__init__(init_cfg=None)
        self.init_cfg = init_cfg
        self.logger = MMLogger.get_current_instance()

        if roi_extractor_single is not None:
            self.roi_extractor_single = MODELS.build(roi_extractor_single)
            self.roi_merge_proj = nn.Linear(768 * 2, 768)
        else:
            self.roi_extractor_single = None
            self.roi_merge_proj = None
        mask_decoder = MaskDecoder(
            num_multimask_outputs=cur_mask - 1,
            transformer_dim=meta_dict[model_name]['prompt_embed_dim'],
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            with_iou=False
        )

        if self.init_cfg['type'] == 'sam_pretrain':
            raise NotImplementedError
        self.mask_decoder = mask_decoder
        self.with_label_token = with_label_token

        if self.with_label_token:
            ov_path = os.path.join(os.path.expanduser('./models/'), f"{ov_classifier_name}.pth")
            cls_embed = torch.load(ov_path)
            cls_embed_norm = cls_embed.norm(p=2, dim=-1)
            assert torch.allclose(cls_embed_norm, torch.ones_like(cls_embed_norm))

            _dim = cls_embed.size(2)
            _prototypes = cls_embed.size(1)
            back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cpu')
            cls_embed = torch.cat([
                cls_embed, back_token.repeat(_prototypes, 1)[None]
            ], dim=0)
            self.register_buffer('cls_embed', cls_embed.permute(2, 0, 1).contiguous(), persistent=False)

            if logit is None:
                logit_scale = torch.tensor(4.6052, dtype=torch.float32)
            else:
                logit_scale = torch.tensor(logit, dtype=torch.float32)
            self.register_buffer('logit_scale', logit_scale, persistent=False)

            transformer_dim = self.mask_decoder.mask_tokens.weight.shape[1]
            self.label_token = nn.Embedding(1, transformer_dim)
            self.label_mlp = SAMMLP(transformer_dim, transformer_dim, _dim, 3)
        self.gen_box = gen_box

        if roi_extractor is not None:
            self.roi = MODELS.build(roi_extractor)
            self.roi_conv = nn.Sequential(*[
                ConvModule(in_channels=self.roi.out_channels, out_channels=_dim, kernel_size=1, bias=False)
            ])
        else:
            self.roi = None
        if self.init_cfg['type'] == 'Pretrained':
            checkpoint_path = pretrained
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=self.init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)

        if roi_extractor is not None and load_roi_conv is not None:
            checkpoint_path = load_roi_conv['checkpoint']
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=load_roi_conv['prefix'])
            self.roi_conv.load_state_dict(state_dict, strict=True)

        self.fix = fix
        if self.fix:
            self.train(mode=False)
            for name, param in self.named_parameters():
                param.requires_grad = False

    def init_weights(self):
        self.logger.info(f"Init Config for {self.__class__.__name__}")
        self.logger.info(self.init_cfg)

    def forward_logit(self, cls_embd):
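        """Classify query embeddings by cosine similarity against the class prototypes.

        ``cls_embd`` is (B, N, C); ``self.cls_embed`` is (C, num_classes + 1, num_prototypes).
        The best prototype per class is kept and logits are scaled by ``exp(logit_scale)``.
        """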
        cls_pred = torch.einsum('bnc,ckp->bnkp', F.normalize(cls_embd, dim=-1), self.cls_embed)
        cls_pred = cls_pred.max(-1).values
        cls_pred = self.logit_scale.exp() * cls_pred
        return cls_pred

    def predict_masks(
            self,
            image_embeddings: torch.Tensor,
            image_pe: torch.Tensor,
            sparse_prompt_embeddings: torch.Tensor,
            dense_prompt_embeddings: torch.Tensor,
            fpn_feats: List[torch.Tensor],
            roi_list: Optional[List[torch.Tensor]],
            backbone_feature: torch.Tensor,
            backbone=None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        num_instances = int(sparse_prompt_embeddings.size(0))
        # Concatenate output tokens
        output_tokens = torch.cat([
            self.label_token.weight,
            self.mask_decoder.mask_tokens.weight], dim=0
        )
        output_tokens = output_tokens.unsqueeze(0).expand(num_instances, -1, -1)
        queries = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # image_embeddings = torch.repeat_interleave(image_embeddings, num_instances, dim=0)
        image_embeddings = image_embeddings + dense_prompt_embeddings
        pos_img = torch.repeat_interleave(image_pe, num_instances, dim=0)
        b, c, h, w = image_embeddings.shape

        # Run the transformer
        queries, mask_feats = self.mask_decoder.transformer(image_embeddings, pos_img, queries)
        label_query = queries[:, 0, :]
        mask_embeds = queries[:, 1:(1 + self.mask_decoder.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        mask_feats = mask_feats.transpose(1, 2).view(b, c, h, w)
        mask_feats = self.mask_decoder.output_upscaling(mask_feats)
        mask_queries_list: List[torch.Tensor] = []
        for i in range(self.mask_decoder.num_mask_tokens):
            mask_queries_list.append(self.mask_decoder.output_hypernetworks_mlps[i](mask_embeds[:, i, :]))
        mask_queries = torch.stack(mask_queries_list, dim=1)
        b, c, h, w = mask_feats.shape
        masks = (mask_queries @ mask_feats.view(b, c, h * w)).view(b, -1, h, w)
        # Generate class labels
        if self.with_label_token:
            cls_embed_list = []
            assert self.mask_decoder.num_mask_tokens == 1
            for i in range(self.mask_decoder.num_mask_tokens):
                cls_embed_list.append(self.label_mlp(label_query))
            cls_embed = torch.stack(cls_embed_list, dim=1)
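
            # Classification uses ROI-pooled image features rather than the raw label
            # query: boxes come either from the predicted masks (gen_box) or from the
            # externally supplied roi_list.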
            if self.gen_box:
                bboxes = mask2bbox(masks.sigmoid()[:, 0] > 0.5) * 4
                roi_list = bbox2roi([bboxes])
            roi_feats = self.roi(fpn_feats, roi_list)
            roi_feats = self.roi_conv(roi_feats)
            roi_feats = roi_feats.mean(dim=-1).mean(dim=-1)

            if self.roi_extractor_single:
                roi_feats_clip = self.roi_extractor_single(
                    backbone.get_clip_feature(backbone_feature[-1:]), roi_list
                )
                roi_feats_clip = backbone.forward_feat(roi_feats_clip)
                roi_feats = self.roi_merge_proj(torch.cat([roi_feats, roi_feats_clip], dim=-1))
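
            # Broadcast the pooled ROI feature over the (single) mask token; the
            # "+ 0 * cls_embed" term adds nothing numerically but presumably keeps the
            # label-token branch connected to the output graph.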
            roi_feats = roi_feats[:, None] + 0 * cls_embed
            cls_pred = self.forward_logit(roi_feats)
        else:
            cls_pred = None
        return masks, None, cls_pred

    def forward(
            self,
            image_embeddings: torch.Tensor,
            image_pe: torch.Tensor,
            sparse_prompt_embeddings: torch.Tensor,
            dense_prompt_embeddings: torch.Tensor,
            multi_mask_output: bool,
            data_samples=None,
            fpn_feats=None,
            backbone_feats=None,
            backbone=None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
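        """Decode masks (and class logits) for every prompt in the batch.

        Returns ``(masks, None, cls_pred)``; the middle element is kept for interface
        compatibility only, since the decoder is built with ``with_iou=False``.
        """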
        num_prompts = len(sparse_prompt_embeddings)
        image_embeddings = torch.repeat_interleave(image_embeddings, num_prompts, dim=0)
        masks, _, cls_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            fpn_feats=fpn_feats,
            roi_list=None,
            backbone_feature=backbone_feats,
            backbone=backbone,
        )

        # Select the correct mask or masks for output
        if multi_mask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]

        # Prepare output
        return masks, None, cls_pred