import base64
from io import BytesIO
from time import perf_counter
from typing import Any, List, Optional, Union

import numpy as np
import onnxruntime
import rasterio.features
import torch
from segment_anything import SamPredictor, sam_model_registry
from shapely.geometry import Polygon as ShapelyPolygon

from inference.core.entities.requests.inference import InferenceRequestImage
from inference.core.entities.requests.sam import (
    SamEmbeddingRequest,
    SamInferenceRequest,
    SamSegmentationRequest,
)
from inference.core.entities.responses.sam import (
    SamEmbeddingResponse,
    SamSegmentationResponse,
)
from inference.core.env import SAM_MAX_EMBEDDING_CACHE_SIZE, SAM_VERSION_ID
from inference.core.models.roboflow import RoboflowCoreModel
from inference.core.utils.image_utils import load_image_rgb
from inference.core.utils.postprocess import masks2poly


class SegmentAnything(RoboflowCoreModel):
    """SegmentAnything class for handling segmentation tasks.

    Attributes:
        sam: The segmentation model.
        predictor: The predictor for the segmentation model.
        ort_session: ONNX runtime inference session.
        embedding_cache: Cache for embeddings.
        image_size_cache: Cache for image sizes.
        embedding_cache_keys: Keys for the embedding cache.
        low_res_logits_cache: Cache for low resolution logits.
        segmentation_cache_keys: Keys for the segmentation cache.
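
    Example (illustrative only; assumes the encoder/decoder artifacts are
    available through the Roboflow weights cache):
        >>> model = SegmentAnything()
        >>> img = ...  # any image loadable by load_image_rgb
        >>> embedding, im_size = model.embed_image(img, image_id="sample123")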
| """ | |
| def __init__(self, *args, model_id: str = f"sam/{SAM_VERSION_ID}", **kwargs): | |
| """Initializes the SegmentAnything. | |
| Args: | |
| *args: Variable length argument list. | |
| **kwargs: Arbitrary keyword arguments. | |
| """ | |
        super().__init__(*args, model_id=model_id, **kwargs)
        self.sam = sam_model_registry[self.version_id](
            checkpoint=self.cache_file("encoder.pth")
        )
        self.sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
        self.predictor = SamPredictor(self.sam)
        self.ort_session = onnxruntime.InferenceSession(
            self.cache_file("decoder.onnx"),
            providers=[
                "CUDAExecutionProvider",
                "CPUExecutionProvider",
            ],
        )
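        # FIFO caches keyed by image_id, bounded by SAM_MAX_EMBEDDING_CACHE_SIZE
        # (the oldest entry is evicted first).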
        self.embedding_cache = {}
        self.image_size_cache = {}
        self.embedding_cache_keys = []
        self.low_res_logits_cache = {}
        self.segmentation_cache_keys = []
        self.task_type = "unsupervised-segmentation"

    def get_infer_bucket_file_list(self) -> List[str]:
        """Gets the list of files required for inference.

        Returns:
            List[str]: List of file names.
        """
        return ["encoder.pth", "decoder.onnx"]

    def embed_image(self, image: Any, image_id: Optional[str] = None, **kwargs):
        """
        Embeds an image and caches the result if an image_id is provided. If the image has been embedded
        and cached before, the cached result will be returned.

        Args:
            image (Any): The image to be embedded. The format should be compatible with the preproc_image method.
            image_id (Optional[str]): An identifier for the image. If provided, the embedding result will be cached
                with this ID. Defaults to None.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image
                and the second element is the shape (height, width) of the processed image.

        Notes:
            - Embeddings and image sizes are cached to improve performance on repeated requests for the same image.
            - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size,
              the oldest entries are removed.

        Example:
            >>> img_array = ...  # some image array
            >>> embed_image(img_array, image_id="sample123")
            (array([...]), (224, 224))
        """
        if image_id and image_id in self.embedding_cache:
            return (
                self.embedding_cache[image_id],
                self.image_size_cache[image_id],
            )
        img_in = self.preproc_image(image)
        self.predictor.set_image(img_in)
        embedding = self.predictor.get_image_embedding().cpu().numpy()
        if image_id:
            self.embedding_cache[image_id] = embedding
            self.image_size_cache[image_id] = img_in.shape[:2]
            self.embedding_cache_keys.append(image_id)
            # Evict the oldest cached embedding once the cache exceeds its bound.
            if len(self.embedding_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
                cache_key = self.embedding_cache_keys.pop(0)
                del self.embedding_cache[cache_key]
                del self.image_size_cache[cache_key]
        return (embedding, img_in.shape[:2])

    def infer_from_request(self, request: SamInferenceRequest):
        """Performs inference based on the request type.

        Args:
            request (SamInferenceRequest): The inference request.

        Returns:
            Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response.
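
        Example (illustrative; the request fields mirror those consumed below):
            >>> request = SamEmbeddingRequest(image=img, image_id="sample123", format="json")
            >>> response = model.infer_from_request(request)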
| """ | |
| t1 = perf_counter() | |
| if isinstance(request, SamEmbeddingRequest): | |
| embedding, _ = self.embed_image(**request.dict()) | |
| inference_time = perf_counter() - t1 | |
| if request.format == "json": | |
| return SamEmbeddingResponse( | |
| embeddings=embedding.tolist(), time=inference_time | |
| ) | |
| elif request.format == "binary": | |
| binary_vector = BytesIO() | |
| np.save(binary_vector, embedding) | |
| binary_vector.seek(0) | |
| return SamEmbeddingResponse( | |
| embeddings=binary_vector.getvalue(), time=inference_time | |
| ) | |
        elif isinstance(request, SamSegmentationRequest):
            masks, low_res_masks = self.segment_image(**request.dict())
            if request.format == "json":
                masks = masks > self.predictor.model.mask_threshold
                masks = masks2poly(masks)
                low_res_masks = low_res_masks > self.predictor.model.mask_threshold
                low_res_masks = masks2poly(low_res_masks)
            elif request.format == "binary":
                binary_vector = BytesIO()
                np.savez_compressed(
                    binary_vector, masks=masks, low_res_masks=low_res_masks
                )
                binary_vector.seek(0)
                binary_data = binary_vector.getvalue()
                return binary_data
            else:
                raise ValueError(f"Invalid format {request.format}")
            response = SamSegmentationResponse(
                masks=[m.tolist() for m in masks],
                low_res_masks=[m.tolist() for m in low_res_masks],
                time=perf_counter() - t1,
            )
            return response

    def preproc_image(self, image: InferenceRequestImage):
        """Preprocesses an image.

        Args:
            image (InferenceRequestImage): The image to preprocess.

        Returns:
            np.ndarray: The preprocessed RGB image.
        """
        np_image = load_image_rgb(image)
        return np_image

    def segment_image(
        self,
        image: Any,
        embeddings: Optional[Union[np.ndarray, List[List[float]]]] = None,
        embeddings_format: Optional[str] = "json",
        has_mask_input: Optional[bool] = False,
        image_id: Optional[str] = None,
        mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None,
        mask_input_format: Optional[str] = "json",
        orig_im_size: Optional[List[int]] = None,
        point_coords: Optional[List[List[float]]] = None,
        point_labels: Optional[List[int]] = None,
        use_mask_input_cache: Optional[bool] = True,
        **kwargs,
    ):
| """ | |
| Segments an image based on provided embeddings, points, masks, or cached results. | |
| If embeddings are not directly provided, the function can derive them from the input image or cache. | |
| Args: | |
| image (Any): The image to be segmented. | |
| embeddings (Optional[Union[np.ndarray, List[List[float]]]]): The embeddings of the image. | |
| Defaults to None, in which case the image is used to compute embeddings. | |
| embeddings_format (Optional[str]): Format of the provided embeddings; either 'json' or 'binary'. Defaults to 'json'. | |
| has_mask_input (Optional[bool]): Specifies whether mask input is provided. Defaults to False. | |
| image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks. | |
| mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input mask for the image. | |
| mask_input_format (Optional[str]): Format of the provided mask input; either 'json' or 'binary'. Defaults to 'json'. | |
| orig_im_size (Optional[List[int]]): Original size of the image when providing embeddings directly. | |
| point_coords (Optional[List[List[float]]]): Coordinates of points in the image. Defaults to an empty list. | |
| point_labels (Optional[List[int]]): Labels associated with the provided points. Defaults to an empty list. | |
| use_mask_input_cache (Optional[bool]): Flag to determine if cached mask input should be used. Defaults to True. | |
| **kwargs: Additional keyword arguments. | |
| Returns: | |
| Tuple[np.ndarray, np.ndarray]: A tuple where the first element is the segmentation masks of the image | |
| and the second element is the low resolution segmentation masks. | |
| Raises: | |
| ValueError: If necessary inputs are missing or inconsistent. | |
| Notes: | |
| - Embeddings, segmentations, and low-resolution logits can be cached to improve performance | |
| on repeated requests for the same image. | |
| - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size, | |
| the oldest entries are removed. | |
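
        Example (illustrative; assumes embeddings for "sample123" were cached
        earlier via embed_image):
            >>> masks, low_res_masks = model.segment_image(
            ...     image=None,
            ...     image_id="sample123",
            ...     point_coords=[[125.0, 275.0]],
            ...     point_labels=[1],
            ... )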
| """ | |
| if not embeddings: | |
| if not image and not image_id: | |
| raise ValueError( | |
| "Must provide either image, cached image_id, or embeddings" | |
| ) | |
| elif image_id and not image and image_id not in self.embedding_cache: | |
| raise ValueError( | |
| f"Image ID {image_id} not in embedding cache, must provide the image or embeddings" | |
| ) | |
| embedding, original_image_size = self.embed_image( | |
| image=image, image_id=image_id | |
| ) | |
| else: | |
| if not orig_im_size: | |
| raise ValueError( | |
| "Must provide original image size if providing embeddings" | |
| ) | |
| original_image_size = orig_im_size | |
| if embeddings_format == "json": | |
| embedding = np.array(embeddings) | |
| elif embeddings_format == "binary": | |
| embedding = np.load(BytesIO(embeddings)) | |
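        # The SAM ONNX decoder expects a padding point with label -1 appended
        # to the prompt points (it stands in for the absent box prompt).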
        # Copy the inputs so the caller's lists are not mutated by the appends below.
        point_coords = list(point_coords) if point_coords else []
        point_coords.append([0, 0])
        point_coords = np.array(point_coords, dtype=np.float32)
        point_coords = np.expand_dims(point_coords, axis=0)
        point_coords = self.predictor.transform.apply_coords(
            point_coords,
            original_image_size,
        )
        point_labels = list(point_labels) if point_labels else []
        point_labels.append(-1)
        point_labels = np.array(point_labels, dtype=np.float32)
        point_labels = np.expand_dims(point_labels, axis=0)
        if has_mask_input:
            if (
                image_id
                and image_id in self.low_res_logits_cache
                and use_mask_input_cache
            ):
                mask_input = self.low_res_logits_cache[image_id]
            elif mask_input is None and (
                not image_id or image_id not in self.low_res_logits_cache
            ):
                raise ValueError("Must provide either mask_input or cached image_id")
            else:
                if mask_input_format == "json":
                    # Rasterize each polygon onto the 256x256 grid the decoder expects.
                    polys = mask_input
                    mask_input = np.zeros((1, len(polys), 256, 256), dtype=np.uint8)
                    for i, poly in enumerate(polys):
                        poly = ShapelyPolygon(poly)
                        raster = rasterio.features.rasterize(
                            [poly], out_shape=(256, 256)
                        )
                        mask_input[0, i, :, :] = raster
                elif mask_input_format == "binary":
                    binary_data = base64.b64decode(mask_input)
                    mask_input = np.load(BytesIO(binary_data))
        else:
            mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
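        # Input names below follow the signature of the SAM ONNX decoder export.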
        ort_inputs = {
            "image_embeddings": embedding.astype(np.float32),
            "point_coords": point_coords.astype(np.float32),
            "point_labels": point_labels,
            "mask_input": mask_input.astype(np.float32),
            "has_mask_input": (
                np.zeros(1, dtype=np.float32)
                if not has_mask_input
                else np.ones(1, dtype=np.float32)
            ),
            "orig_im_size": np.array(original_image_size, dtype=np.float32),
        }
        masks, _, low_res_logits = self.ort_session.run(None, ort_inputs)
        if image_id:
            # Cache the low-res logits so follow-up requests can reuse them as mask_input.
            self.low_res_logits_cache[image_id] = low_res_logits
            if image_id not in self.segmentation_cache_keys:
                self.segmentation_cache_keys.append(image_id)
            if len(self.segmentation_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE:
                cache_key = self.segmentation_cache_keys.pop(0)
                del self.low_res_logits_cache[cache_key]
        # Drop the batch dimension before returning.
        masks = masks[0]
        low_res_masks = low_res_logits[0]
        return masks, low_res_masks