nicklorch committed
Commit c6dbef6
Parent(s): ff5a99d

changes from bobs repo to line up text embeddings

Files changed (1): handler.py (+5 -6)
handler.py CHANGED
@@ -5,7 +5,7 @@ import logging
 
 from PIL import Image
 import torch
-from transformers import CLIPProcessor, CLIPTextModel, CLIPVisionModelWithProjection
+from transformers import CLIPProcessor, CLIPModel
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 logger = logging.getLogger(__name__)
@@ -13,8 +13,7 @@ logger.setLevel('INFO')
 
 class EndpointHandler():
     def __init__(self, path=""):
-        self.text_model = CLIPTextModel.from_pretrained("rbanfield/clip-vit-large-patch14").to(device)
-        self.image_model = CLIPVisionModelWithProjection.from_pretrained("rbanfield/clip-vit-large-patch14").to(device)
+        self.model = CLIPModel.from_pretrained("rbanfield/clip-vit-large-patch14").to("cpu")
         self.processor = CLIPProcessor.from_pretrained("rbanfield/clip-vit-large-patch14")
 
     def __call__(self, data):
@@ -23,7 +22,7 @@ class EndpointHandler():
         inputs = data.pop("inputs", None)
         text_input = None
         image_data = None
-
+        logger.info('data contents: %s', data)
         if isinstance(inputs, Image.Image):
             logger.info('image sent directly')
             image = inputs
@@ -38,12 +37,12 @@ class EndpointHandler():
         if text_input:
             processor = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
             with torch.no_grad():
-                return {'embeddings':self.text_model(**processor).pooler_output.tolist()[0]}
+                return {"embeddings": self.model.get_text_features(**processor).tolist()}
         elif image:
             # image = Image.open(image_data)
             processor = self.processor(images=image, return_tensors="pt").to(device)
             with torch.no_grad():
-                return {'embeddings':self.image_model(**processor).image_embeds.tolist()[0]}
+                return {"embeddings": self.model.get_image_features(**processor).tolist()}
         else:
             return {'embeddings':None}
         except Exception as ex:
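
Why the switch lines the embeddings up: in transformers, CLIPTextModel's pooler_output is the pooled hidden state before CLIP's text projection, while CLIPVisionModelWithProjection.image_embeds is already projected, so the old handler returned text and image vectors from different spaces. CLIPModel.get_text_features and get_image_features both apply the projections. Below is a minimal sketch (not part of this commit) of checking that the two outputs now live in the same joint space, assuming rbanfield/clip-vit-large-patch14 loads like a standard CLIP ViT-L/14 checkpoint; the image path is a placeholder.

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("rbanfield/clip-vit-large-patch14").eval()
processor = CLIPProcessor.from_pretrained("rbanfield/clip-vit-large-patch14")

text_inputs = processor(text=["a photo of a cat"], return_tensors="pt", padding=True)
image_inputs = processor(images=Image.open("cat.jpg"), return_tensors="pt")  # "cat.jpg" is a placeholder path

with torch.no_grad():
    text_emb = model.get_text_features(**text_inputs)     # projected text embedding, shape (1, projection_dim)
    image_emb = model.get_image_features(**image_inputs)  # projected image embedding, same dimensionality

# Both vectors are in CLIP's joint space, so cosine similarity is meaningful.
print(torch.nn.functional.cosine_similarity(text_emb, image_emb).item())

Note also that the new handler returns embeddings via .tolist() rather than .tolist()[0], so the response payload is a nested list ([[...]]) instead of a flat vector; callers expecting the old shape would need to index the first element.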