# chongzhou's picture
# push app.py to huggingface
# ddaab9d
import torch
import torch.nn as nn
from ..modeling import Sam
from .amg import calculate_stability_score
class SamCoreMLModel(nn.Module):
    """
    This model should not be called directly, but is used in CoreML export.

    Wraps a ``Sam`` model's prompt encoder and mask decoder. ``forward`` takes
    image embeddings plus point prompts and returns ``(scores, masks)``.
    When ``use_stability_score`` is True, the scores returned by
    ``predict_masks`` are replaced by ``calculate_stability_score`` computed
    from the masks.
    """

    def __init__(
        self,
        model: "Sam",
        use_stability_score: bool = False,
        stability_score_offset: float = 1.0,
    ) -> None:
        """
        Args:
            model: The SAM model whose ``prompt_encoder``, ``mask_decoder``,
                ``image_encoder.img_size`` and ``mask_threshold`` are reused.
            use_stability_score: If True, ``forward`` returns stability scores
                instead of the scores produced by ``predict_masks``.
            stability_score_offset: Offset passed to
                ``calculate_stability_score``; only used when
                ``use_stability_score`` is True. Defaults to 1.0, the value
                that was previously hard-coded.
        """
        super().__init__()
        # Same object as ``model.mask_decoder``; forward() reaches it through
        # ``self.model`` instead.
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.use_stability_score = use_stability_score
        self.stability_score_offset = stability_score_offset

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        """
        Build the sparse prompt embedding for a set of points.

        Coordinates are shifted by 0.5 and divided by ``img_size`` before
        being passed to the prompt encoder's positional encoding. The
        per-point embedding then depends on its label:

        * label -1: the positional encoding is zeroed and
          ``not_a_point_embed.weight`` is added instead;
        * label i for i in ``range(num_point_embeddings)``:
          ``point_embeddings[i].weight`` is added to the positional encoding;
        * any other label: the positional encoding is kept unchanged.

        Selection is done by multiplying with boolean masks rather than by
        indexing.

        Args:
            point_coords: Point coordinates in input-image pixels. Assumed to
                be (B, N, 2); TODO confirm against the export caller.
            point_labels: Integer label per point. Assumed to be (B, N).

        Returns:
            The point embeddings, with the shape of the positional encoding
            output.
        """
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        # Broadcast labels across the embedding channel dimension so they can
        # be used as element-wise masks below.
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
            point_labels == -1
        )

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
                i
            ].weight * (point_labels == i)

        return point_embedding

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
    ):
        """
        Decode masks for the given point prompts.

        Args:
            image_embeddings: Image embeddings passed straight to
                ``predict_masks`` (presumably the image encoder's output,
                computed outside this module).
            point_coords: Point coordinates in input-image pixels; see
                ``_embed_points``.
            point_labels: Per-point labels; see ``_embed_points``.

        Returns:
            A ``(scores, masks)`` tuple. Note the order is the reverse of
            ``predict_masks``' ``(masks, scores)``. When
            ``use_stability_score`` is True, ``scores`` are the stability
            scores of ``masks``.
        """
        sparse_embedding = self._embed_points(point_coords, point_labels)
        # No mask prompt input exists here, so the dense prompt is always the
        # learned "no mask" embedding reshaped to (1, C, 1, 1).
        dense_embedding = self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        return scores, masks