File size: 2,199 Bytes
b926327 a4e73a2 884b4bf b926327 e62633b 884b4bf e62633b b926327 e62633b b926327 a4e73a2 dd31dc4 c04fdf8 a4e73a2 c04fdf8 dd31dc4 ed041c3 c04fdf8 29107e0 c04fdf8 ff5a99d b926327 a4e73a2 29107e0 a4e73a2 884b4bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from io import BytesIO
import base64
import traceback
import logging
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection
# Run on the GPU when torch reports one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Module-level logger; INFO is required for the handler's info messages to be emitted.
logger = logging.getLogger(__name__)
logger.setLevel('INFO')
class EndpointHandler():
    """Callable endpoint handler that returns a CLIP embedding for text or an image.

    The request payload is read from ``data["inputs"]`` and may be either a
    ``PIL.Image.Image`` or a dict with an optional ``"text"`` entry and an
    optional ``"image"`` entry holding base64-encoded image bytes. When both
    text and an image are supplied, the text is embedded and the image is
    ignored.
    """

    # Single checkpoint used for the text model, the vision model and the processor.
    MODEL_ID = "rbanfield/clip-vit-large-patch14"

    def __init__(self, path=""):
        """Load the CLIP text model, vision model and processor.

        Args:
            path: Accepted for interface compatibility only; it is not used.
                Weights are always loaded from ``MODEL_ID``.
        """
        # NOTE(review): text embeddings are taken from CLIPTextModel.pooler_output
        # (a model class without the projection head) while image embeddings are
        # the projected ``image_embeds``. Presumably the two are therefore not in
        # the same embedding space -- confirm before comparing text to images.
        self.text_model = CLIPTextModel.from_pretrained(self.MODEL_ID).to(device)
        self.image_model = CLIPVisionModelWithProjection.from_pretrained(self.MODEL_ID).to(device)
        self.processor = CLIPProcessor.from_pretrained(self.MODEL_ID)

    @staticmethod
    def _parse_inputs(inputs):
        """Split the raw ``inputs`` payload into a ``(text, image)`` pair.

        Either element may be ``None``. Raises ``TypeError`` for payload types
        that are neither a dict nor a PIL image.
        """
        if inputs is None:
            return None, None
        if isinstance(inputs, dict):
            text = inputs.get("text")
            encoded_image = inputs.get("image")
            if encoded_image is None:
                return text, None
            logger.info('image is encoded')
            return text, Image.open(BytesIO(base64.b64decode(encoded_image)))
        if isinstance(inputs, Image.Image):
            logger.info('image sent directly')
            return None, inputs
        raise TypeError(
            "unsupported 'inputs' type: %s (expected a dict or a PIL image)"
            % type(inputs).__name__
        )

    def __call__(self, data):
        """Embed the text or image found in ``data["inputs"]``.

        Args:
            data: Request dict. Its ``"inputs"`` entry is removed (popped) from
                the dict as part of handling the request.

        Returns:
            ``{'embeddings': [...]}`` holding the first row of the model output
            as a list, ``{'embeddings': None}`` when the payload contains
            nothing to embed (no inputs, empty text and no image), or
            ``{'Error': <formatted traceback>}`` if anything raises.
        """
        try:
            text_input, image = self._parse_inputs(data.pop("inputs", None))

            # Text takes precedence over an image when both are present.
            if text_input:
                encoded = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
                with torch.no_grad():
                    return {'embeddings': self.text_model(**encoded).pooler_output.tolist()[0]}

            # Explicit None check: ``image`` is always bound here, so a payload
            # with neither text nor image falls through instead of raising.
            if image is not None:
                encoded = self.processor(images=image, return_tensors="pt").to(device)
                with torch.no_grad():
                    return {'embeddings': self.image_model(**encoded).image_embeds.tolist()[0]}

            return {'embeddings': None}
        except Exception as ex:
            # Top-level boundary of the endpoint: log once (with traceback) and
            # report the failure in the response rather than propagating it.
            logger.exception('error doing request: %s', ex)
            return {'Error': traceback.format_exc()}
|