nicklorch's picture
image payload fix
ff5a99d
raw
history blame
2.2 kB
from io import BytesIO
import base64
import traceback
import logging
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection
# Pick the compute device once at import time: the GPU when torch can see one,
# otherwise the CPU. Models and input tensors are moved to this device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Module-level logger; INFO is enabled so the per-request messages are emitted.
logger = logging.getLogger(__name__)
logger.setLevel('INFO')
class EndpointHandler():
    """Callable request handler that returns a CLIP embedding for text or an image.

    The handler loads a CLIP text model, a CLIP vision model (with projection)
    and the matching processor once at construction time, then serves one
    request per ``__call__``.
    """

    def __init__(self, path=""):
        """Load the CLIP models and processor and move the models to ``device``.

        Args:
            path: Accepted for interface compatibility but not used; the
                weights are always loaded from the hard-coded hub id below.
        """
        model_id = "rbanfield/clip-vit-large-patch14"
        self.text_model = CLIPTextModel.from_pretrained(model_id).to(device)
        self.image_model = CLIPVisionModelWithProjection.from_pretrained(model_id).to(device)
        self.processor = CLIPProcessor.from_pretrained(model_id)

    def __call__(self, data):
        """Embed the text or image carried in ``data["inputs"]``.

        ``data["inputs"]`` (removed from ``data`` by this call) may be:

        * a ``PIL.Image.Image`` instance, which is embedded directly;
        * a mapping with a ``"text"`` entry and/or an ``"image"`` entry, where
          ``"image"`` is base64-encoded image bytes;
        * missing / ``None``, in which case there is nothing to embed.

        A truthy ``"text"`` entry takes precedence over ``"image"``.

        Args:
            data: Request payload; must support ``pop("inputs", None)``.

        Returns:
            ``{'embeddings': <list of floats>}`` holding the first row of the
            model output, ``{'embeddings': None}`` when neither text nor an
            image was supplied, or ``{'Error': <formatted traceback>}`` when
            anything raised while handling the request.
        """
        try:
            inputs = data.pop("inputs", None)
            text_input = None
            # Always bind `image` so the "nothing supplied" case reaches the
            # final branch instead of raising UnboundLocalError.
            image = None

            if isinstance(inputs, Image.Image):
                logger.info('image sent directly')
                image = inputs
            elif inputs is not None:
                text_input = inputs.get("text")
                image_data = inputs.get("image")
                if image_data is not None:
                    logger.info('image is encoded')
                    image = Image.open(BytesIO(base64.b64decode(image_data)))

            if text_input:
                encoded = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
                with torch.no_grad():
                    # NOTE(review): this is the text model's pooler_output, while
                    # the image branch returns projected image_embeds — confirm
                    # callers do not expect both to share one embedding space.
                    # Only the first row of the batch is returned.
                    return {'embeddings': self.text_model(**encoded).pooler_output.tolist()[0]}
            elif image is not None:
                encoded = self.processor(images=image, return_tensors="pt").to(device)
                with torch.no_grad():
                    return {'embeddings': self.image_model(**encoded).image_embeds.tolist()[0]}
            else:
                return {'embeddings': None}
        except Exception as ex:
            # Endpoint boundary: report the failure to the caller rather than
            # propagating. logger.exception already records the traceback.
            logger.exception('error doing request: %s', ex)
            return {'Error': traceback.format_exc()}