rbanfield committed
Commit dc1351d · 1 Parent(s): e1b71b0

Create handler.py

Files changed (1)
  1. handler.py +37 -0
handler.py ADDED
@@ -0,0 +1,37 @@
+ from io import BytesIO
+ import base64
+
+ from PIL import Image
+ import torch
+ from transformers import CLIPProcessor, CLIPModel
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         # Load the CLIP model and processor once, placing the model on the selected device
+         self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+     def __call__(self, data):
+         text_input = None
+         if isinstance(data, dict):
+             # JSON payload: {"inputs": {"text": ..., "image": <base64>}}
+             inputs = data.pop("inputs", {})
+             text_input = inputs.get('text', None)
+             image_data = BytesIO(base64.b64decode(inputs['image'])) if 'image' in inputs else None
+         else:
+             # Otherwise assume the payload is an image sent as raw binary
+             image_data = BytesIO(data)
+
+         if text_input:
+             processed = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
+             with torch.no_grad():
+                 return {"embeddings": self.model.get_text_features(**processed).to("cpu").tolist()}
+         elif image_data:
+             image = Image.open(image_data)
+             processed = self.processor(images=image, return_tensors="pt").to(device)
+             with torch.no_grad():
+                 return {"embeddings": self.model.get_image_features(**processed).to("cpu").tolist()}
+         else:
+             return {"embeddings": None}
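
For reference, a minimal sketch of how the handler might be exercised locally (not part of the commit): it assumes the file above is saved as handler.py in the working directory, and "cat.jpg" and the prompt text are placeholder inputs.

import base64
from handler import EndpointHandler  # assumes the file above is on the import path as handler.py

handler = EndpointHandler()

# Text embedding: dict payloads with a "text" key are routed to get_text_features
text_out = handler({"inputs": {"text": ["a photo of a cat"]}})
print(len(text_out["embeddings"][0]))  # length of the CLIP text embedding

# Image embedding: "cat.jpg" is a placeholder path for any local image file
with open("cat.jpg", "rb") as f:
    image_payload = {"inputs": {"image": base64.b64encode(f.read()).decode("utf-8")}}
image_out = handler(image_payload)
print(len(image_out["embeddings"][0]))  # length of the CLIP image embedding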