dazpye committed
Commit fe492f5 · verified · 1 parent: 4280311

Update handler.py

Files changed (1):
  handler.py +32 -20
handler.py CHANGED
@@ -1,32 +1,44 @@
  import torch
  from transformers import CLIPProcessor, CLIPModel

  class EndpointHandler:
-     def __init__(self):
-         # Load model and processor
          self.model = CLIPModel.from_pretrained("dazpye/clip-image")
          self.processor = CLIPProcessor.from_pretrained("dazpye/clip-image")

-     def preprocess(self, inputs):
-         # Process input data
-         text = inputs.get("text", [])
-         images = inputs.get("images", [])
-         return self.processor(text=text, images=images, return_tensors="pt")

-     def inference(self, inputs):
-         # Run inference
          with torch.no_grad():
              outputs = self.model(**inputs)
-         return outputs.logits_per_image.tolist()
-
-     def postprocess(self, inference_output):
-         # Convert output to readable format
-         return {"predictions": inference_output}

- handler = EndpointHandler()

- def handle(request):
-     inputs = request if isinstance(request, dict) else request.json()
-     processed_inputs = handler.preprocess(inputs)
-     predictions = handler.inference(processed_inputs)
-     return handler.postprocess(predictions)
 
  import torch
  from transformers import CLIPProcessor, CLIPModel
+ from PIL import Image
+ import requests  # needed by _load_image for URL-based images
+ import base64
+ import io

  class EndpointHandler:
+     def __init__(self, model_dir=None):  # AWS expects model_dir
+         print("Loading model...")
          self.model = CLIPModel.from_pretrained("dazpye/clip-image")
          self.processor = CLIPProcessor.from_pretrained("dazpye/clip-image")

+     def _load_image(self, image_data):
+         """Handles both URL-based and base64 image inputs."""
+         if isinstance(image_data, str):
+             if image_data.startswith("http"):
+                 return Image.open(requests.get(image_data, stream=True).raw)
+             else:  # Assume base64-encoded image
+                 return Image.open(io.BytesIO(base64.b64decode(image_data)))
+         return None  # Invalid image format

+     def __call__(self, data):
+         """Main entry point called by Hugging Face Inference Endpoints."""
+         print("Processing input...")
+
+         text = data.get("text", ["default caption"])  # Default text
+         images = data.get("images", [])  # List of images
+
+         # Convert image URLs or base64 strings to PIL images
+         pil_images = [self._load_image(img) for img in images if img]
+
+         if not pil_images:
+             return {"error": "No valid images provided."}
+
+         inputs = self.processor(text=text, images=pil_images, return_tensors="pt")
+
+         print("Running inference...")
          with torch.no_grad():
              outputs = self.model(**inputs)

+         logits_per_image = outputs.logits_per_image
+         probabilities = logits_per_image.softmax(dim=1)
+
+         return {"predictions": probabilities.tolist()}
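For reference, here is a minimal local smoke test of the updated handler. The payload shape (a "text" list of candidate captions plus an "images" list of URLs or base64 strings) follows the __call__ signature above; the file name cat.jpg and the caption strings are placeholder assumptions, not part of the commit.

import base64

from handler import EndpointHandler

handler = EndpointHandler()  # instantiated once, as the endpoint runtime would

# Encode a local test image as base64 so _load_image takes the non-URL branch.
with open("cat.jpg", "rb") as f:  # hypothetical test image
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "text": ["a photo of a cat", "a photo of a dog"],  # candidate captions
    "images": [image_b64],
}

result = handler(payload)
print(result)  # {"predictions": [[p_cat, p_dog]]}: softmax scores over the two captions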