nicklorch commited on
Commit
29107e0
·
1 Parent(s): 884b4bf

figured out image issue

Browse files
Files changed (1) hide show
  1. handler.py +14 -11
handler.py CHANGED
@@ -22,23 +22,26 @@ class EndpointHandler():
22
 
23
  logger.info('data is %s', data)
24
  text_input = None
25
- if isinstance(data, dict):
26
- print('data is a dict: ', data)
27
- logger.info('data is a dict %s', data)
28
- inputs = data.pop("inputs", None)
29
- text_input = inputs["text"] if "text" in inputs else None
30
- image_data = BytesIO(base64.b64decode(inputs['image'])) if 'image' in inputs else None
31
- else:
32
- # assuming its an image sent via binary
33
- image_data = BytesIO(data)
34
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  if text_input:
37
  processor = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
38
  with torch.no_grad():
39
  return {'embeddings':self.text_model(**processor).pooler_output.tolist()[0]}
40
- elif image_data:
41
- image = Image.open(image_data)
42
  processor = self.processor(images=image, return_tensors="pt").to(device)
43
  with torch.no_grad():
44
  return {'embeddings':self.image_model(**processor).image_embeds.tolist()[0]}
 
22
 
23
  logger.info('data is %s', data)
24
  text_input = None
 
 
 
 
 
 
 
 
 
25
 
26
+ logger.info('data is a dict %s', data)
27
+ inputs = data.pop("inputs", None)
28
+ text_input = inputs["text"] if "text" in inputs else None
29
+ image_data = inputs['image'] if 'image' in inputs else None
30
+
31
+ if image_data is not None:
32
+ if isinstance(image_data, Image):
33
+ logger.info('image is an image')
34
+ image = image_data
35
+ else:
36
+ logger.info('image is encoded')
37
+ image = BytesIO(base64.b64decode(image_data))
38
 
39
  if text_input:
40
  processor = self.processor(text=text_input, return_tensors="pt", padding=True).to(device)
41
  with torch.no_grad():
42
  return {'embeddings':self.text_model(**processor).pooler_output.tolist()[0]}
43
+ elif image:
44
+ # image = Image.open(image_data)
45
  processor = self.processor(images=image, return_tensors="pt").to(device)
46
  with torch.no_grad():
47
  return {'embeddings':self.image_model(**processor).image_embeds.tolist()[0]}