# sam-small / handler.py
# Author: Tony Neel
# Commit message: satisfy huggingface init call
# Commit: 305c627
from typing import Dict, List, Any, Union
from sam2.sam2_image_predictor import SAM2ImagePredictor
import torch
import numpy as np
from PIL import Image
import io
import base64
from huggingface_hub import InferenceEndpoint
class EndpointHandler(InferenceEndpoint):
    """Request handler that segments an image with a SAM 2 image predictor.

    An instance is callable: it takes either encoded image data or a dict
    carrying the image plus optional point prompts, and returns a
    JSON-serializable dict with the predicted masks and their scores.
    """

    # NOTE(review): `huggingface_hub.InferenceEndpoint` looks like a
    # client-side object for managing deployed endpoints rather than a base
    # class for handlers, and its initializer is never invoked here. The base
    # class is kept so the class hierarchy callers see is unchanged -- confirm
    # whether the inheritance is actually required.

    def __init__(self, model_dir=None):
        """Create the SAM 2 predictor used for every request.

        Args:
            model_dir (str, optional): Path to model directory. Accepted so
                the handler can be constructed with a directory argument, but
                not used: the predictor is always built from the
                "facebook/sam2-hiera-small" identifier. Defaults to None.
        """
        # For local testing without the real model, assign any object that
        # exposes `set_image(image)` and
        # `predict(point_coords=..., point_labels=...)` returning
        # `(masks, scores, extra)` to `self.predictor` instead.
        self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")

    def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
        """Load image from binary or base64 data.

        Args:
            image_data: Encoded image bytes, or a base64 string of those bytes.

        Returns:
            The opened PIL image, in whatever mode the file declares.

        Raises:
            ValueError: If the data cannot be base64-decoded or opened as an
                image. The underlying exception is chained as the cause.
        """
        try:
            # Strings are treated as base64 text; anything else is handed to
            # BytesIO as raw encoded image bytes.
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)
            return Image.open(io.BytesIO(image_data))
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise ValueError(f"Failed to load image: {e}") from e

    def _run_predictor(self, image_array, point_coords, point_labels):
        """Set the image on the predictor and run a single prediction.

        Returns:
            The `(masks, scores)` pair from the predictor; the third element
            of the predictor's result is discarded.
        """
        self.predictor.set_image(image_array)
        masks, scores, _ = self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
        )
        return masks, scores

    def __call__(self, image_bytes: Union[str, bytes, Dict[str, Any]]) -> Dict[str, Any]:
        """Segment one image and return masks and scores.

        Args:
            image_bytes: Either the image itself (encoded bytes or a base64
                string), or a dict with a required 'image' entry in one of
                those forms and optional 'point_coords' / 'point_labels'
                entries that are forwarded to the predictor unchanged.

        Returns:
            On success, a dict with "masks" (one nested list per mask),
            "scores" (a list, or None when the predictor returned no scores)
            and "status" set to "success". If the predictor returned no
            masks, a dict with "error" and "status" set to "error".

        Raises:
            KeyError: If a dict payload has no 'image' entry.
            ValueError: If the image data cannot be decoded.
        """
        # Get point prompts if provided in request
        if isinstance(image_bytes, dict):
            point_coords = image_bytes.get('point_coords')
            point_labels = image_bytes.get('point_labels')
            image_bytes = image_bytes['image']
        else:
            point_coords = None
            point_labels = None

        # Decode through the shared loader so base64 strings work as well as
        # raw bytes, then normalize to RGB before converting to an array.
        image = self._load_image(image_bytes)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_array = np.array(image)

        # Run inference; bfloat16 autocast is applied only when CUDA is
        # available, otherwise the predictor runs without autocast.
        with torch.inference_mode():
            if torch.cuda.is_available():
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    masks, scores = self._run_predictor(
                        image_array, point_coords, point_labels
                    )
            else:
                masks, scores = self._run_predictor(
                    image_array, point_coords, point_labels
                )

        # Format output
        if masks is None:
            return {"error": "No masks generated", "status": "error"}
        return {
            "masks": [mask.tolist() for mask in masks],
            "scores": scores.tolist() if scores is not None else None,
            "status": "success",
        }