srinath-wp commited on
Commit
588bce7
·
verified ·
1 Parent(s): b4e0815

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +14 -4
  2. requirements.txt +1 -0
handler.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  from PIL import Image
3
  import io
4
  import base64
 
5
 
6
  from model import MedSAM2Model
7
 
@@ -11,7 +12,12 @@ class EndpointHandler:
11
  self.model = MedSAM2Model().to(self.device)
12
  self.model.eval()
13
 
 
14
  def preprocess(self, inputs):
 
 
 
 
15
  image_b64 = inputs.get("image")
16
  if not image_b64:
17
  raise ValueError("Missing 'image' field in input.")
@@ -19,10 +25,14 @@ class EndpointHandler:
19
  image_bytes = base64.b64decode(image_b64)
20
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
21
 
22
- # Convert PIL Image to PyTorch tensor (C x H x W), normalized [0,1]
23
- image_tensor = torch.from_numpy(
24
- (np.array(image) / 255.0).astype('float32')
25
- ).permute(2, 0, 1).unsqueeze(0) # Add batch dimension
 
 
 
 
26
 
27
  return image_tensor.to(self.device)
28
 
 
2
  from PIL import Image
3
  import io
4
  import base64
5
+ import torchvision.transforms as T
6
 
7
  from model import MedSAM2Model
8
 
 
12
  self.model = MedSAM2Model().to(self.device)
13
  self.model.eval()
14
 
15
+
16
  def preprocess(self, inputs):
17
+ # Unwrap if "inputs" key exists
18
+ if "inputs" in inputs:
19
+ inputs = inputs["inputs"]
20
+
21
  image_b64 = inputs.get("image")
22
  if not image_b64:
23
  raise ValueError("Missing 'image' field in input.")
 
25
  image_bytes = base64.b64decode(image_b64)
26
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
27
 
28
+ # Transform PIL image to tensor and normalize (example)
29
+ transform = T.Compose([
30
+ T.ToTensor(), # Converts to tensor and scales pixels to [0,1]
31
+ # Add normalization if your model requires it, e.g.:
32
+ # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33
+ ])
34
+
35
+ image_tensor = transform(image).unsqueeze(0) # Add batch dim: [1, 3, H, W]
36
 
37
  return image_tensor.to(self.device)
38
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  torch
 
2
  Pillow
3
  huggingface_hub
 
1
  torch
2
+ torchvision
3
  Pillow
4
  huggingface_hub