from typing import Any, Dict

import base64
from io import BytesIO

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class EndpointHandler:
    def __init__(self, path: str = "") -> None:
        """Preload all the elements we need at inference."""
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.path = path

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference on a text/image pair and return the CLIP embeddings."""
        inputs = data.get("inputs")
        text = inputs.get("text")
        image_data = inputs.get("image")

        # The image arrives as a base64-encoded string in the request payload.
        image = Image.open(BytesIO(base64.b64decode(image_data)))

        # Tokenize the text and preprocess the image into model-ready tensors.
        model_inputs = self.processor(text=text, images=image, return_tensors="pt", padding=True)

        # No gradients are needed at inference time.
        with torch.no_grad():
            outputs = self.model(**model_inputs)

        image_embeds = outputs.image_embeds.numpy().flatten().tolist()
        text_embeds = outputs.text_embeds.numpy().flatten().tolist()
        logits_per_image = outputs.logits_per_image.numpy().flatten().tolist()

        return {
            "image_embeddings": image_embeds,
            "text_embeddings": text_embeds,
            "logits_per_image": logits_per_image,
        }
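

# Minimal local smoke test for the handler (a sketch, not part of the
# Inference Endpoints request/response contract). "example.jpg" is a
# hypothetical path; substitute any RGB image on disk.
if __name__ == "__main__":
    handler = EndpointHandler()

    with open("example.jpg", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "inputs": {
            "text": ["a photo of a cat", "a photo of a dog"],
            "image": encoded_image,
        }
    }
    result = handler(payload)

    print(len(result["image_embeddings"]))  # 512-dimensional CLIP image embedding
    print(len(result["text_embeddings"]))   # 512 values per input text, flattened
    print(result["logits_per_image"])       # one similarity score per text prompt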