# Last commit by rbanfield: 4dbb20c "Another GPU test" (file size 1.37 kB)
from io import BytesIO
import base64
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection
# Target device for both CLIP models and their input tensors: the default CUDA
# device when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class EndpointHandler:
    """Request handler that turns either text or a base64-encoded image into CLIP vectors.

    The text model, vision model and processor are loaded once at construction;
    both models are moved to the module-level ``device``.
    """

    def __init__(self, path="", model_id="rbanfield/clip-vit-large-patch14"):
        """Load the CLIP text model, vision model and processor.

        Args:
            path: Accepted for compatibility with the handler calling convention;
                it is not used -- weights always come from ``model_id``.
            model_id: Identifier passed to ``from_pretrained`` for all three
                components. Defaults to the checkpoint this handler always loaded.
        """
        # NOTE(review): the text branch of __call__ returns CLIPTextModel's
        # ``pooler_output`` (no projection head), while the image branch returns
        # the *projected* ``image_embeds`` of CLIPVisionModelWithProjection. The two
        # are therefore not in CLIP's shared text/image embedding space, so
        # comparing a text vector with an image vector is not meaningful.
        # CLIPTextModelWithProjection(...).text_embeds would align them, but that
        # changes every returned text vector -- confirm with consumers first.
        self.text_model = CLIPTextModel.from_pretrained(model_id).to(device)
        self.image_model = CLIPVisionModelWithProjection.from_pretrained(model_id).to(device)
        self.processor = CLIPProcessor.from_pretrained(model_id)

    def __call__(self, data):
        """Embed the text or image carried in ``data["inputs"]``.

        Args:
            data: Request payload. The ``"inputs"`` key is popped from it (the
                dict is mutated). ``inputs`` is expected to be a dict with either a
                ``"text"`` entry (passed straight to the CLIP processor, e.g. a
                string or list of strings) or an ``"image"`` entry holding a
                base64-encoded image file.

        Returns:
            A nested list of floats, one row per input. When a truthy ``"text"``
            is present it is used, even if an image is also given, and the
            text model's ``pooler_output`` is returned. Otherwise a truthy
            ``"image"`` yields the vision model's ``image_embeds``. Returns
            ``None`` when ``inputs`` is missing, is not a dict, or carries
            neither value.
        """
        inputs = data.pop("inputs", None)
        # JSON request bodies deserialize to dict. Anything else (None, str, list)
        # previously crashed with TypeError or was probed with substring/element
        # checks such as ``"text" in "context"``, so reject it up front.
        if not isinstance(inputs, dict):
            return None

        text_input = inputs.get("text")
        image_input = inputs.get("image")

        if text_input:
            batch = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
            with torch.no_grad():
                return self.text_model(**batch).pooler_output.tolist()

        if image_input:
            image = Image.open(BytesIO(base64.b64decode(image_input)))
            batch = self.processor(images=image, return_tensors="pt").to(device)
            with torch.no_grad():
                return self.image_model(**batch).image_embeds.tolist()

        return None