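# Page residue from the hosting site's file view, kept as comments so the module parses:
#   vision
#   nicklorch's picture
#   more debugging
#   884b4bf
#   raw
#   history blame
#   2.2 kB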
from io import BytesIO
import base64
import traceback
import logging
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection
# Prefer the GPU whenever torch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level logger for this handler, with its threshold pinned to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class EndpointHandler:
    """Serve CLIP (ViT-L/14) embeddings for either a text prompt or an image.

    Each call handles one request and returns ``{'embeddings': [...]}`` holding
    the first embedding of the encoded batch, ``{'embeddings': None}`` when the
    request carries neither text nor an image, or ``{'Error': <traceback>}`` if
    anything raised while handling it.
    """

    # Checkpoint used for the text model, the vision model and the processor.
    MODEL_ID = "rbanfield/clip-vit-large-patch14"

    def __init__(self, path=""):
        """Load the text model, vision model and processor onto ``device``.

        Args:
            path: Accepted for the handler constructor signature but not used;
                the checkpoint is always ``MODEL_ID``.
        """
        self.text_model = CLIPTextModel.from_pretrained(self.MODEL_ID).to(device)
        self.image_model = CLIPVisionModelWithProjection.from_pretrained(self.MODEL_ID).to(device)
        self.processor = CLIPProcessor.from_pretrained(self.MODEL_ID)

    def __call__(self, data):
        """Embed the text or image carried by one request.

        Args:
            data: Either a dict whose ``"inputs"`` entry is a dict with an
                optional ``"text"`` value and/or an optional base64-encoded
                ``"image"`` value, or a raw image body (anything ``BytesIO``
                accepts).

        Returns:
            ``{'embeddings': list}`` for the first item of the batch; text wins
            when both text and an image are supplied. ``{'embeddings': None}``
            when neither is present. ``{'Error': str}`` containing the formatted
            traceback if processing failed (including malformed requests).
        """
        try:
            # Log only the request type: the payload may be a large base64 image
            # or user content that should not land in INFO logs.
            logger.debug('received request of type %s', type(data).__name__)
            text_input, image_data = self._parse_request(data)
            if text_input:
                return {'embeddings': self._embed_text(text_input)}
            if image_data is not None:
                return {'embeddings': self._embed_image(image_data)}
            return {'embeddings': None}
        except Exception as ex:
            # Top-level boundary of the endpoint: log once (with traceback) and
            # report the failure in the response body instead of raising.
            # NOTE(review): the traceback is returned to the client; consider
            # whether exposing internals is acceptable for this deployment.
            stack_info = traceback.format_exc()
            logger.exception('error doing request: %s', ex)
            return {'Error': stack_info}

    @staticmethod
    def _parse_request(data):
        """Return ``(text, image_stream)`` from a request; either may be None.

        Raises:
            ValueError: if a dict request does not carry an ``"inputs"`` dict.
        """
        if not isinstance(data, dict):
            # Non-dict bodies are assumed to be the binary image itself.
            return None, BytesIO(data)
        # Read with .get rather than .pop so the caller's dict is left intact.
        inputs = data.get("inputs")
        if not isinstance(inputs, dict):
            # Without this, None fails with an opaque TypeError and a str makes
            # the key checks below silently become substring tests.
            raise ValueError(
                f"expected 'inputs' to be a dict with 'text' and/or 'image', "
                f"got {type(inputs).__name__}"
            )
        text_input = inputs.get("text")
        image_data = BytesIO(base64.b64decode(inputs["image"])) if "image" in inputs else None
        return text_input, image_data

    def _embed_text(self, text_input):
        """Encode ``text_input`` and return the first row of ``pooler_output``."""
        # NOTE(review): pooler_output of CLIPTextModel is not passed through CLIP's
        # text projection, while _embed_image returns projected image_embeds. If
        # text and image embeddings are meant to be compared, the usual pairing is
        # CLIPTextModelWithProjection + .text_embeds -- confirm before switching,
        # since it would change every text embedding already stored by clients.
        encoded = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            return self.text_model(**encoded).pooler_output[0].tolist()

    def _embed_image(self, image_data):
        """Decode the image stream and return the first row of ``image_embeds``."""
        # Close the PIL image once the processor has consumed it.
        with Image.open(image_data) as image:
            encoded = self.processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            return self.image_model(**encoded).image_embeds[0].tolist()