import base64
import io
from contextlib import nullcontext
from typing import Dict, List, Any, Union

import numpy as np
import torch
from huggingface_hub import InferenceEndpoint
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor


class EndpointHandler(InferenceEndpoint):
    """Callable request handler that segments an image with SAM 2.

    The handler loads a SAM 2 image predictor once at construction time and
    then, per call, decodes the incoming image, runs the predictor (optionally
    guided by point prompts) and returns the masks as JSON-friendly lists.
    """

    # NOTE(review): the base class is kept for backward compatibility only.
    # Its initializer is never invoked and none of its attributes are used
    # here -- confirm whether this handler needs to derive from it at all.

    # Hugging Face Hub id of the checkpoint loaded by ``__init__``.
    MODEL_ID = "facebook/sam2-hiera-small"

    def __init__(self, model_dir=None):
        """Load the SAM 2 image predictor.

        Args:
            model_dir (str, optional): Path to the model directory. Accepted
                for interface compatibility but currently unused: the weights
                are always fetched with ``from_pretrained``. Defaults to None.
        """
        # Choose the device explicitly so that construction also works on
        # hosts without a GPU; ``__call__`` already has a CPU code path.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.predictor = SAM2ImagePredictor.from_pretrained(
            self.MODEL_ID, device=device
        )

    def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
        """Load an image from binary or base64-encoded data.

        Args:
            image_data: Raw encoded image bytes, or a base64 string of those
                bytes. A leading ``data:...;base64,`` URI header is tolerated.

        Returns:
            The decoded PIL image, in whatever mode the file was stored in.

        Raises:
            ValueError: If the data cannot be decoded into an image.
        """
        try:
            if isinstance(image_data, str):
                # ":" is not in the base64 alphabet, so a "data:" prefix can
                # only be a data-URI header; drop it before decoding.
                if image_data.startswith("data:") and "," in image_data:
                    image_data = image_data.split(",", 1)[1]
                image_data = base64.b64decode(image_data)

            image = Image.open(io.BytesIO(image_data))
            # Image.open is lazy; force the decode now so corrupt or truncated
            # data is reported from here rather than at first pixel access.
            image.load()
            return image
        except Exception as exc:
            raise ValueError(f"Failed to load image: {str(exc)}") from exc

    def __call__(self, image_bytes) -> Dict[str, Any]:
        """Run segmentation on one request.

        Args:
            image_bytes: Either the encoded image itself (bytes or a base64
                string), or a dict with the image under ``'image'`` and
                optional ``'point_coords'`` / ``'point_labels'`` prompts.

        Returns:
            ``{"masks": ..., "scores": ..., "status": "success"}`` when the
            predictor returns masks, otherwise
            ``{"error": "No masks generated", "status": "error"}``.

        Raises:
            KeyError: If a dict payload has no ``'image'`` entry.
            ValueError: If the image cannot be decoded, or if point
                coordinates are supplied without matching labels.
        """
        if isinstance(image_bytes, dict):
            point_coords = image_bytes.get('point_coords')
            point_labels = image_bytes.get('point_labels')
            image_bytes = image_bytes['image']
        else:
            point_coords = None
            point_labels = None

        # Go through the shared loader so base64 strings (the form an image
        # takes inside a JSON body) are accepted as well as raw bytes.
        image = self._load_image(image_bytes)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_array = np.array(image)

        # Prompts decoded from JSON arrive as plain lists; hand the predictor
        # arrays and fail early with a clear message on an incomplete prompt.
        if point_coords is not None:
            if point_labels is None:
                raise ValueError(
                    "point_labels must be provided together with point_coords"
                )
            point_coords = np.asarray(point_coords, dtype=np.float32)
        if point_labels is not None:
            point_labels = np.asarray(point_labels, dtype=np.int32)

        # bfloat16 autocast is only applied when CUDA is available; on CPU the
        # predictor runs under a no-op context instead.
        if torch.cuda.is_available():
            autocast_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
        else:
            autocast_ctx = nullcontext()

        with torch.inference_mode(), autocast_ctx:
            self.predictor.set_image(image_array)
            masks, scores, _ = self.predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
            )

        if masks is not None:
            return {
                "masks": [mask.tolist() for mask in masks],
                "scores": scores.tolist() if scores is not None else None,
                "status": "success",
            }
        return {"error": "No masks generated", "status": "error"}