rayh commited on
Commit
543627a
·
verified ·
1 Parent(s): 3777e37

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +9 -25
model.py CHANGED
@@ -1,29 +1,13 @@
1
- import numpy as np
2
  import onnxruntime as ort
3
- import torch
4
- from huggingface_hub import hf_hub_download
5
  from PIL import Image
6
 
7
-
8
  class YOLOSegmentationModel:
9
- def __init__(self):
10
- # Download and load the ONNX model from Hugging Face Hub
11
- model_path = hf_hub_download(repo_id="rayh/astro-seg", filename="astro-yolo11m-seg.onnx")
12
- self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
13
-
14
- def preprocess(self, image: Image.Image):
15
- # Convert image to RGB and preprocess for ONNX model
16
- input_array = np.array(image.convert("RGB")).astype(np.float32)
17
- input_array = np.expand_dims(input_array, axis=0) # Add batch dimension
18
- return input_array
19
-
20
- def predict(self, image: Image.Image):
21
- input_tensor = self.preprocess(image)
22
- outputs = self.session.run(None, {"images": input_tensor})
23
- return outputs # Modify if needed to return bounding boxes/masks
24
-
25
- model = YOLOSegmentationModel()
26
-
27
- # HF Inference API expects a `predict` function
28
- def predict(image: Image.Image):
29
- return model.predict(image)
 
 
1
  import onnxruntime as ort
2
+ import numpy as np
 
3
  from PIL import Image
4
 
 
5
  class YOLOSegmentationModel:
6
+ def __init__(self, model_path: str):
7
+ self.session = ort.InferenceSession(model_path)
8
+
9
+ def predict(self, image: Image):
10
+ input_data = np.array(image).astype(np.float32)
11
+ input_data = np.expand_dims(input_data, axis=0) # Add batch dimension
12
+ outputs = self.session.run(None, {"images": input_data})
13
+ return outputs