Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from ..modeling import Sam | |
from .amg import calculate_stability_score | |
class SamCoreMLModel(nn.Module):
    """
    Export-only wrapper around a ``Sam`` model, used for CoreML conversion.

    It is not meant to be called directly. ``forward`` takes image
    embeddings together with point prompts and returns ``(scores, masks)``
    as produced by the wrapped model's ``mask_decoder.predict_masks``.
    """

    def __init__(
        self,
        model: Sam,
        use_stability_score: bool = False
    ) -> None:
        """
        Store the wrapped model and the export options.

        Arguments:
          model: the model whose prompt encoder and mask decoder are used.
          use_stability_score: when True, ``forward`` returns the output of
            ``calculate_stability_score`` in place of the decoder's scores.
        """
        super().__init__()
        # The decoder is registered directly and also remains reachable
        # through ``self.model``; both attributes are kept.
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.use_stability_score = use_stability_score
        # Offset passed through to calculate_stability_score.
        self.stability_score_offset = 1.0

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        """
        Turn point prompts into sparse prompt embeddings.

        Coordinates are shifted by 0.5, divided by ``self.img_size`` and fed
        to the prompt encoder's positional encoding. Entries labelled ``-1``
        have their encoding zeroed and replaced by the ``not_a_point_embed``
        weight; entries labelled ``i`` get ``point_embeddings[i].weight``
        added on top of their encoding.
        """
        prompt_encoder = self.model.prompt_encoder

        scaled_coords = (point_coords + 0.5) / self.img_size
        embedding = prompt_encoder.pe_layer._pe_encoding(scaled_coords)

        # Broadcast the labels to the embedding's shape so that every
        # comparison below yields an element-wise mask.
        labels = point_labels.unsqueeze(-1).expand_as(embedding)

        is_padding = labels == -1
        embedding = embedding * (labels != -1)
        embedding = embedding + prompt_encoder.not_a_point_embed.weight * is_padding

        for label_id in range(prompt_encoder.num_point_embeddings):
            learned = prompt_encoder.point_embeddings[label_id].weight
            embedding = embedding + learned * (labels == label_id)

        return embedding

    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
    ):
        """
        Predict masks for the given image embeddings and point prompts.

        No mask prompt is taken: the dense prompt embedding is always the
        prompt encoder's ``no_mask_embed`` weight reshaped to
        ``(1, -1, 1, 1)``.

        Returns:
          A ``(scores, masks)`` tuple. ``scores`` is the decoder's second
          output, or the result of ``calculate_stability_score`` when
          ``use_stability_score`` is set.
        """
        prompt_encoder = self.model.prompt_encoder

        sparse_prompt = self._embed_points(point_coords, point_labels)
        dense_prompt = prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_prompt,
            dense_prompt_embeddings=dense_prompt,
        )

        if not self.use_stability_score:
            return scores, masks

        # Swap the decoder's own scores for a stability score computed
        # from the predicted masks.
        stability = calculate_stability_score(
            masks, self.model.mask_threshold, self.stability_score_offset
        )
        return stability, masks