|
import torch |
|
from PIL import Image |
|
import io |
|
import base64 |
|
import torchvision.transforms as T |
|
|
|
from model import MedSAM2Model |
|
|
|
class EndpointHandler:
    """Inference handler that runs ``MedSAM2Model`` on a base64-encoded image.

    Accepted request payloads:

    * ``{"inputs": {"image": "<base64>"}}``
    * ``{"image": "<base64>"}``
    * ``{"inputs": "<base64>"}``

    The base64 string may optionally start with a data-URI header such as
    ``data:image/png;base64,``.
    """

    def __init__(self, path=""):
        """Build the model on CUDA when available, otherwise on CPU.

        Args:
            path: Accepted for interface compatibility; not used here.
                NOTE(review): if the hosting platform passes the model
                repository/weights directory here, confirm that
                ``MedSAM2Model()`` loads its own weights.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = MedSAM2Model().to(self.device)
        # Put modules such as dropout / batch-norm into inference behaviour.
        self.model.eval()

    def preprocess(self, inputs):
        """Decode the request payload into a batched image tensor.

        Args:
            inputs: Request payload; see the class docstring for accepted forms.

        Returns:
            Float tensor of shape ``(1, 3, H, W)`` with values in ``[0, 1]``,
            placed on ``self.device``.

        Raises:
            ValueError: If the payload has an unsupported type, the image
                field is missing or empty, or it is not decodable base64.
        """
        payload = inputs
        if isinstance(payload, dict) and "inputs" in payload:
            payload = payload["inputs"]

        if isinstance(payload, (str, bytes)):
            # Bare base64 string, e.g. {"inputs": "<base64>"}.
            image_b64 = payload
        elif isinstance(payload, dict):
            image_b64 = payload.get("image")
        else:
            raise ValueError(
                f"Unsupported payload type {type(payload).__name__!r}; "
                "expected a dict with an 'image' field or a base64 string."
            )

        if not image_b64:
            raise ValueError("Missing 'image' field in input.")

        # b64decode (validate=False) silently drops non-alphabet characters,
        # so a data-URI header would otherwise be decoded as garbage bytes
        # ('/' is a valid base64 character). Strip it explicitly.
        if isinstance(image_b64, str) and image_b64.startswith("data:"):
            image_b64 = image_b64.partition(",")[2]

        try:
            image_bytes = base64.b64decode(image_b64)
        except ValueError as err:  # binascii.Error is a ValueError subclass
            raise ValueError("Field 'image' is not valid base64.") from err

        with Image.open(io.BytesIO(image_bytes)) as img:
            # Force three channels so grayscale or RGBA inputs share one layout.
            image = img.convert("RGB")

        # ToTensor: HWC uint8 PIL image -> CHW float in [0, 1]; then add a
        # leading batch dimension.
        image_tensor = T.ToTensor()(image).unsqueeze(0)
        # NOTE(review): no resizing or normalisation is applied here — confirm
        # that MedSAM2Model accepts arbitrary H, W and raw [0, 1] input.
        return image_tensor.to(self.device)

    def postprocess(self, output):
        """Convert the model output into a JSON-serialisable response.

        Tensors are moved to CPU and turned into nested Python lists. Tuples,
        lists and dicts of tensors are converted element-wise, so models with
        several outputs can be served as well.

        Returns:
            ``{"output": <converted output>}``.
        """
        return {"output": self._to_jsonable(output)}

    @classmethod
    def _to_jsonable(cls, value):
        """Recursively replace tensors (possibly nested in containers) with lists."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().tolist()
        if isinstance(value, dict):
            return {key: cls._to_jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [cls._to_jsonable(item) for item in value]
        return value

    def __call__(self, inputs):
        """Preprocess ``inputs``, run a gradient-free forward pass, and postprocess."""
        x = self.preprocess(inputs)
        with torch.no_grad():
            output = self.model(x)
        return self.postprocess(output)
|
|