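# File-view header (author: rbanfield; commit e62633b, "Try enabling GPU if
# available"; size 1.35 kB). The "raw" and "history blame" entries were page
# links, not part of the source code.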
from io import BytesIO
import base64
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection
# Compute device chosen once at import time: the CUDA device when torch
# reports one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler:
    """Inference handler that embeds either text or an image with CLIP.

    The text tower, the vision tower (with projection) and the shared
    processor are loaded once in ``__init__``. Each call then embeds whichever
    of ``"text"`` or ``"image"`` is present in the request's ``inputs`` dict;
    text takes precedence when both are truthy.
    """

    # Single source of truth for the checkpoint loaded by every component.
    MODEL_ID = "rbanfield/clip-vit-large-patch14"

    def __init__(self, path=""):
        """Load the CLIP text model, vision model and processor.

        Args:
            path: Accepted for interface compatibility; not used by this
                implementation (the checkpoint id is ``MODEL_ID``).
        """
        # Keep the device on the instance so __call__ can move each input batch
        # next to the model weights.
        self.device = device
        self.text_model = CLIPTextModel.from_pretrained(self.MODEL_ID).to(self.device)
        self.image_model = CLIPVisionModelWithProjection.from_pretrained(self.MODEL_ID).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(self.MODEL_ID)

    def __call__(self, data):
        """Embed the text or base64-encoded image found in ``data["inputs"]``.

        Args:
            data: Request payload dict. ``data["inputs"]`` is popped, so ``data``
                is mutated. ``inputs`` is expected to be a dict with a ``"text"``
                key (presumably a string or list of strings, handed to the
                processor as-is) or an ``"image"`` key (base64-encoded image
                file bytes).

        Returns:
            A nested list of floats produced by ``Tensor.tolist()``: the text
            model's ``pooler_output`` or the image model's ``image_embeds``.
            Returns ``None`` when ``inputs`` is missing, or when it has neither
            a truthy ``"text"`` nor a truthy ``"image"`` value.

        Raises:
            TypeError: If ``inputs`` is present but is not a dict.
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            # Nothing to embed; same result as the "neither text nor image" case.
            return None
        if not isinstance(inputs, dict):
            raise TypeError("'inputs' must be a dict with a 'text' or 'image' key")

        text_input = inputs.get("text")
        image_input = inputs.get("image")

        if text_input:
            # The processor builds CPU tensors. Move them to the model's device,
            # otherwise a CUDA-resident model raises a device-mismatch error.
            batch = self.processor(text=text_input, return_tensors="pt", padding=True).to(self.device)
            with torch.no_grad():
                # NOTE(review): per the transformers docs, CLIPTextModel's
                # pooler_output is not passed through CLIP's text projection,
                # unlike image_embeds below. If text/image similarity is the
                # goal, CLIPTextModelWithProjection(...).text_embeds may be
                # what is wanted. Confirm with callers before changing it.
                return self.text_model(**batch).pooler_output.tolist()

        if image_input:
            # Decode the payload and let the processor read the image before the
            # file handle is closed.
            with Image.open(BytesIO(base64.b64decode(image_input))) as image:
                batch = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                return self.image_model(**batch).image_embeds.tolist()

        return None