"""Inference endpoint handler that runs a MedSAM2 model on a base64 image."""

import base64
import io
from collections.abc import Mapping

import torch
import torchvision.transforms as T
from PIL import Image

from model import MedSAM2Model

# Built once at import time instead of on every request.
# ``ToTensor`` converts a PIL image to a float tensor scaled to [0, 1].
# NOTE: if the model requires normalized inputs, append e.g.
# ``T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])``.
_TRANSFORM = T.Compose([T.ToTensor()])


class EndpointHandler:
    """Callable handler: decode a request image, run the model, return lists.

    The request payload is a mapping with a base64-encoded ``"image"`` entry,
    optionally wrapped as ``{"inputs": {...}}``.
    """

    def __init__(self, path: str = "") -> None:
        """Instantiate the model on GPU when available, otherwise on CPU.

        Args:
            path: Accepted for interface compatibility; not used here.
                NOTE(review): no weights are loaded from ``path`` in this
                file -- confirm ``MedSAM2Model()`` loads its own checkpoint.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = MedSAM2Model().to(self.device)
        self.model.eval()

    def preprocess(self, inputs: Mapping) -> torch.Tensor:
        """Turn a request payload into a batched image tensor on the device.

        Args:
            inputs: Mapping containing ``"image"`` (base64 string or bytes),
                either directly or nested under an ``"inputs"`` key.

        Returns:
            A float tensor with a leading batch dimension of 1, moved to
            ``self.device``.

        Raises:
            ValueError: If the payload is not a mapping, the ``"image"``
                field is missing or empty, or it cannot be base64-decoded.
                Errors raised by PIL while opening the decoded bytes are
                propagated unchanged.
        """
        # Unwrap the payload if it is nested under an "inputs" key.
        if isinstance(inputs, Mapping) and "inputs" in inputs:
            inputs = inputs["inputs"]
        if not isinstance(inputs, Mapping):
            raise ValueError(
                "Expected a mapping with an 'image' field in input."
            )

        image_b64 = inputs.get("image")
        if not image_b64:
            raise ValueError("Missing 'image' field in input.")

        try:
            image_bytes = base64.b64decode(image_b64)
        except (ValueError, TypeError) as err:
            # binascii.Error is a ValueError subclass; TypeError covers
            # non-string / non-bytes values such as numbers.
            raise ValueError("'image' field is not valid base64.") from err

        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # unsqueeze(0) adds the batch dimension: [3, H, W] -> [1, 3, H, W].
        image_tensor = _TRANSFORM(image).unsqueeze(0)
        return image_tensor.to(self.device)

    def postprocess(self, output: torch.Tensor) -> dict:
        """Convert the model output to a JSON-serializable dictionary.

        Args:
            output: Model output; assumed to be a single tensor
                (TODO confirm against ``MedSAM2Model.forward``).

        Returns:
            ``{"output": <nested list of the tensor's values>}``.
        """
        return {"output": output.cpu().tolist()}

    def __call__(self, inputs: Mapping) -> dict:
        """Run preprocessing, a gradient-free forward pass, and postprocessing.

        Args:
            inputs: Request payload as accepted by :meth:`preprocess`.

        Returns:
            The dictionary produced by :meth:`postprocess`.
        """
        x = self.preprocess(inputs)
        with torch.no_grad():
            output = self.model(x)
        return self.postprocess(output)